Skip to content

Commit 6329692

Browse files
committed
feat: Use support_point as default init for pymc
1 parent cad8e99 commit 6329692

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

python/nutpie/compile_pymc.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,12 @@ def compile_pymc_model(
435435
*,
436436
backend: Literal["numba", "jax"] = "numba",
437437
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
438-
overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
438+
initial_points: dict[Union["Variable", str], np.ndarray | float | int]
439+
| None = None,
439440
jitter_rvs: set["TensorVariable"] | None = None,
440-
default_strategy: Literal["support_point", "prior"] = "prior",
441+
default_initialization_strategy: Literal[
442+
"support_point", "prior"
443+
] = "support_point",
441444
**kwargs,
442445
) -> CompiledModel:
443446
"""Compile necessary functions for sampling a pymc model.
@@ -455,10 +458,10 @@ def compile_pymc_model(
455458
The set (or list or tuple) of random variables for which a U(-1, +1)
456459
jitter should be added to the initial value. Only available for
457460
variables that have a transform or real-valued support.
458-
default_strategy : str
461+
default_initialization_strategy : str
459462
Which of { "support_point", "prior" } to prefer if the initval setting
460463
for an RV is None.
461-
overrides : dict
464+
initial_points : dict
462465
Initial value (strategies) to use instead of what's specified in
463466
`Model.initial_values`.
464467
Returns
@@ -475,13 +478,13 @@ def compile_pymc_model(
475478
"and restart your kernel in case you are in an interactive session."
476479
)
477480

478-
if default_strategy == "support_point" and jitter_rvs is None:
481+
if default_initialization_strategy == "support_point" and jitter_rvs is None:
479482
jitter_rvs = set(model.free_RVs)
480483

481484
initial_point_fn = make_initial_point_fn(
482485
model=model,
483-
overrides=overrides,
484-
default_strategy=default_strategy,
486+
overrides=initial_points,
487+
default_strategy=default_initialization_strategy,
485488
jitter_rvs=jitter_rvs,
486489
return_transformed=True,
487490
)

python/nutpie/compiled_pyfunc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from nutpie import _lib
99
from nutpie.sample import CompiledModel
1010

11-
SeedType = int | float | np.random.Generator | None
11+
SeedType = int
1212

1313

1414
@dataclass(frozen=True)
1515
class PyFuncModel(CompiledModel):
1616
_make_logp_func: Callable
1717
_make_expand_func: Callable
18-
_make_initial_points: Callable[[SeedType], np.ndarray]
18+
_make_initial_points: Callable[[SeedType], np.ndarray] | None
1919
_shared_data: dict[str, Any]
2020
_n_dim: int
2121
_variables: list[_lib.PyVariable]
@@ -73,14 +73,14 @@ def from_pyfunc(
7373
ndim: int,
7474
make_logp_fn: Callable,
7575
make_expand_fn: Callable,
76-
make_initial_point_fn: Callable[[SeedType], np.ndarray],
7776
expanded_dtypes: list[np.dtype],
7877
expanded_shapes: list[tuple[int, ...]],
7978
expanded_names: list[str],
8079
*,
8180
coords: dict[str, Any] | None = None,
8281
dims: dict[str, tuple[str, ...]] | None = None,
8382
shared_data: dict[str, Any] | None = None,
83+
make_initial_point_fn: Callable[[SeedType], np.ndarray] | None,
8484
):
8585
variables = []
8686
for name, shape, dtype in zip(

tests/test_pymc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ def test_pymc_model_float32(backend, gradient_backend):
3838
trace.posterior.a # noqa: B018
3939

4040

41+
@parameterize_backends
42+
def test_pymc_model_no_prior(backend, gradient_backend):
43+
with pm.Model() as model:
44+
a = pm.Flat("a")
45+
pm.Normal("b", mu=a, observed=0.0)
46+
47+
compiled = nutpie.compile_pymc_model(
48+
model, backend=backend, gradient_backend=gradient_backend
49+
)
50+
trace = nutpie.sample(compiled, chains=1)
51+
trace.posterior.a # noqa: B018
52+
53+
4154
@parameterize_backends
4255
def test_blocking(backend, gradient_backend):
4356
with pm.Model() as model:

0 commit comments

Comments
 (0)