Skip to content

Commit 50bb956

Browse files
committed
feat: Add option not to store some deterministics
1 parent c100bda commit 50bb956

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

python/nutpie/compile_pymc.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from functools import wraps
66
from importlib.util import find_spec
77
from math import prod
8-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
8+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, Union
99

1010
import numpy as np
1111
import pandas as pd
@@ -218,6 +218,7 @@ def make_user_data(shared_vars, shared_data):
218218
def _compile_pymc_model_numba(
219219
model: "pm.Model",
220220
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
221+
var_names: Iterable[str] | None = None,
221222
**kwargs,
222223
) -> CompiledPyMCModel:
223224
if find_spec("numba") is None:
@@ -242,6 +243,7 @@ def _compile_pymc_model_numba(
242243
compute_grad=True,
243244
join_expanded=True,
244245
pymc_initial_point_fn=pymc_initial_point_fn,
246+
var_names=var_names,
245247
)
246248

247249
expand_fn = expand_fn_pt.vm.jit_fn
@@ -337,6 +339,7 @@ def _compile_pymc_model_jax(
337339
*,
338340
gradient_backend=None,
339341
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
342+
var_names: Iterable[str] | None = None,
340343
**kwargs,
341344
):
342345
if find_spec("jax") is None:
@@ -366,6 +369,7 @@ def _compile_pymc_model_jax(
366369
compute_grad=gradient_backend == "pytensor",
367370
join_expanded=False,
368371
pymc_initial_point_fn=pymc_initial_point_fn,
372+
var_names=var_names,
369373
)
370374

371375
logp_fn = logp_fn_pt.vm.jit_fn
@@ -441,6 +445,7 @@ def compile_pymc_model(
441445
default_initialization_strategy: Literal[
442446
"support_point", "prior"
443447
] = "support_point",
448+
var_names: Iterable[str] | None = None,
444449
**kwargs,
445450
) -> CompiledModel:
446451
"""Compile necessary functions for sampling a pymc model.
@@ -464,6 +469,8 @@ def compile_pymc_model(
464469
initial_points : dict
465470
Initial value (strategies) to use instead of what's specified in
466471
`Model.initial_values`.
472+
var_names : list[str] | None
473+
A list of variables to store in the trace. If None, store all variables.
467474
Returns
468475
-------
469476
compiled_model : CompiledPyMCModel
@@ -493,13 +500,14 @@ def compile_pymc_model(
493500
if gradient_backend == "jax":
494501
raise ValueError("Gradient backend cannot be jax when using numba backend")
495502
return _compile_pymc_model_numba(
496-
model=model, pymc_initial_point_fn=initial_point_fn, **kwargs
503+
model=model, pymc_initial_point_fn=initial_point_fn, var_names=var_names, **kwargs
497504
)
498505
elif backend.lower() == "jax":
499506
return _compile_pymc_model_jax(
500507
model=model,
501508
gradient_backend=gradient_backend,
502509
pymc_initial_point_fn=initial_point_fn,
510+
var_names=var_names,
503511
**kwargs,
504512
)
505513
else:
@@ -542,6 +550,7 @@ def _make_functions(
542550
compute_grad: bool,
543551
join_expanded: bool,
544552
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
553+
var_names: Iterable[str] | None = None,
545554
) -> tuple[
546555
int,
547556
int,
@@ -568,6 +577,8 @@ def _make_functions(
568577
pymc_initial_point_fn: Callable
569578
Initial point function created by
570579
pymc.initial_point.make_initial_point_fn
580+
var_names:
581+
Names of variables to store in the trace. Defaults to all variables.
571582
572583
Returns
573584
-------
@@ -673,6 +684,10 @@ def _make_functions(
673684
var for var in model.unobserved_value_vars if var.name not in joined_names
674685
]
675686

687+
if var_names is not None:
688+
names = set(var_names)
689+
remaining_rvs = [var for var in remaining_rvs if var.name in names]
690+
676691
all_names = joined_names + remaining_rvs
677692

678693
all_names = joined_names.copy()

tests/test_pymc.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,44 @@ def test_pymc_model_shared(backend, gradient_backend):
193193
nutpie.sample(compiled3, chains=1)
194194

195195

196+
@parameterize_backends
197+
def test_pymc_var_names(backend, gradient_backend):
198+
with pm.Model() as model:
199+
mu = pm.Data("mu", -0.1)
200+
sigma = pm.Data("sigma", np.ones(3))
201+
a = pm.Normal("a", mu=mu, sigma=sigma, shape=3)
202+
203+
b = pm.Deterministic("b", mu * a)
204+
pm.Deterministic("c", mu * b)
205+
206+
compiled = nutpie.compile_pymc_model(
207+
model, backend=backend, gradient_backend=gradient_backend, var_names=None,
208+
)
209+
trace = nutpie.sample(compiled, chains=1, seed=1)
210+
211+
# Check that variables are stored
212+
assert hasattr(trace.posterior, "b")
213+
assert hasattr(trace.posterior, "c")
214+
215+
compiled = nutpie.compile_pymc_model(
216+
model, backend=backend, gradient_backend=gradient_backend, var_names=[],
217+
)
218+
trace = nutpie.sample(compiled, chains=1, seed=1)
219+
220+
# Check that variables are stored
221+
assert not hasattr(trace.posterior, "b")
222+
assert not hasattr(trace.posterior, "c")
223+
224+
compiled = nutpie.compile_pymc_model(
225+
model, backend=backend, gradient_backend=gradient_backend, var_names=["b"],
226+
)
227+
trace = nutpie.sample(compiled, chains=1, seed=1)
228+
229+
# Check that variables are stored
230+
assert hasattr(trace.posterior, "b")
231+
assert not hasattr(trace.posterior, "c")
232+
233+
196234
@pytest.mark.parametrize(
197235
("backend", "gradient_backend"),
198236
[

0 commit comments

Comments
 (0)