Skip to content

Commit 51b0454

Browse files
committed
feat: Use support_point as default init for pymc
1 parent cad8e99 commit 51b0454

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

python/nutpie/compile_pymc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,9 @@ def compile_pymc_model(
435435
*,
436436
backend: Literal["numba", "jax"] = "numba",
437437
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
438-
overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
438+
initial_points: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
439439
jitter_rvs: set["TensorVariable"] | None = None,
440-
default_strategy: Literal["support_point", "prior"] = "prior",
440+
default_initialization_strategy: Literal["support_point", "prior"] = "support_point",
441441
**kwargs,
442442
) -> CompiledModel:
443443
"""Compile necessary functions for sampling a pymc model.
@@ -455,10 +455,10 @@ def compile_pymc_model(
455455
The set (or list or tuple) of random variables for which a U(-1, +1)
456456
jitter should be added to the initial value. Only available for
457457
variables that have a transform or real-valued support.
458-
default_strategy : str
458+
default_initialization_strategy : str
459459
Which of { "support_point", "prior" } to prefer if the initval setting
460460
for an RV is None.
461-
overrides : dict
461+
initial_points : dict
462462
Initial value (strategies) to use instead of what's specified in
463463
`Model.initial_values`.
464464
Returns
@@ -475,13 +475,13 @@ def compile_pymc_model(
475475
"and restart your kernel in case you are in an interactive session."
476476
)
477477

478-
if default_strategy == "support_point" and jitter_rvs is None:
478+
if default_initialization_strategy == "support_point" and jitter_rvs is None:
479479
jitter_rvs = set(model.free_RVs)
480480

481481
initial_point_fn = make_initial_point_fn(
482482
model=model,
483-
overrides=overrides,
484-
default_strategy=default_strategy,
483+
overrides=initial_points,
484+
default_strategy=default_initialization_strategy,
485485
jitter_rvs=jitter_rvs,
486486
return_transformed=True,
487487
)

tests/test_pymc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ def test_pymc_model_float32(backend, gradient_backend):
3838
trace.posterior.a # noqa: B018
3939

4040

41+
@parameterize_backends
42+
def test_pymc_model_no_prior(backend, gradient_backend):
43+
with pm.Model() as model:
44+
a = pm.Flat("a")
45+
pm.Normal("b", mu=a, observed=0.)
46+
47+
compiled = nutpie.compile_pymc_model(
48+
model, backend=backend, gradient_backend=gradient_backend
49+
)
50+
trace = nutpie.sample(compiled, chains=1)
51+
trace.posterior.a # noqa: B018
52+
53+
4154
@parameterize_backends
4255
def test_blocking(backend, gradient_backend):
4356
with pm.Model() as model:

0 commit comments

Comments
 (0)