55from functools import wraps
66from importlib .util import find_spec
77from math import prod
8- from typing import TYPE_CHECKING , Any , Callable , Literal , Optional , Union
8+ from typing import TYPE_CHECKING , Any , Callable , Iterable , Literal , Optional , Union
99
1010import numpy as np
1111import pandas as pd
@@ -218,6 +218,7 @@ def make_user_data(shared_vars, shared_data):
218218def _compile_pymc_model_numba (
219219 model : "pm.Model" ,
220220 pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
221+ var_names : Iterable [str ] | None = None ,
221222 ** kwargs ,
222223) -> CompiledPyMCModel :
223224 if find_spec ("numba" ) is None :
@@ -242,6 +243,7 @@ def _compile_pymc_model_numba(
242243 compute_grad = True ,
243244 join_expanded = True ,
244245 pymc_initial_point_fn = pymc_initial_point_fn ,
246+ var_names = var_names ,
245247 )
246248
247249 expand_fn = expand_fn_pt .vm .jit_fn
@@ -337,6 +339,7 @@ def _compile_pymc_model_jax(
337339 * ,
338340 gradient_backend = None ,
339341 pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
342+ var_names : Iterable [str ] | None = None ,
340343 ** kwargs ,
341344):
342345 if find_spec ("jax" ) is None :
@@ -366,6 +369,7 @@ def _compile_pymc_model_jax(
366369 compute_grad = gradient_backend == "pytensor" ,
367370 join_expanded = False ,
368371 pymc_initial_point_fn = pymc_initial_point_fn ,
372+ var_names = var_names ,
369373 )
370374
371375 logp_fn = logp_fn_pt .vm .jit_fn
@@ -441,6 +445,7 @@ def compile_pymc_model(
441445 default_initialization_strategy : Literal [
442446 "support_point" , "prior"
443447 ] = "support_point" ,
448+ var_names : Iterable [str ] | None = None ,
444449 ** kwargs ,
445450) -> CompiledModel :
446451 """Compile necessary functions for sampling a pymc model.
@@ -464,6 +469,8 @@ def compile_pymc_model(
464469 initial_points : dict
465470 Initial value (strategies) to use instead of what's specified in
466471 `Model.initial_values`.
472+ var_names : list[str] | None
473+ A list of variables to store in the trace. If None, store all variables.
467474 Returns
468475 -------
469476 compiled_model : CompiledPyMCModel
@@ -493,13 +500,14 @@ def compile_pymc_model(
493500 if gradient_backend == "jax" :
494501 raise ValueError ("Gradient backend cannot be jax when using numba backend" )
495502 return _compile_pymc_model_numba (
496- model = model , pymc_initial_point_fn = initial_point_fn , ** kwargs
503+ model = model , pymc_initial_point_fn = initial_point_fn , var_names = var_names , ** kwargs
497504 )
498505 elif backend .lower () == "jax" :
499506 return _compile_pymc_model_jax (
500507 model = model ,
501508 gradient_backend = gradient_backend ,
502509 pymc_initial_point_fn = initial_point_fn ,
510+ var_names = var_names ,
503511 ** kwargs ,
504512 )
505513 else :
@@ -542,6 +550,7 @@ def _make_functions(
542550 compute_grad : bool ,
543551 join_expanded : bool ,
544552 pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
553+ var_names : Iterable [str ] | None = None ,
545554) -> tuple [
546555 int ,
547556 int ,
@@ -568,6 +577,8 @@ def _make_functions(
568577 pymc_initial_point_fn: Callable
569578 Initial point function created by
570579 pymc.initial_point.make_initial_point_fn
580+ var_names:
581+ Names of variables to store in the trace. Defaults to all variables.
571582
572583 Returns
573584 -------
@@ -673,6 +684,10 @@ def _make_functions(
673684 var for var in model .unobserved_value_vars if var .name not in joined_names
674685 ]
675686
687+ if var_names is not None :
688+ names = set (var_names )
689+ remaining_rvs = [var for var in remaining_rvs if var .name in names ]
690+
676691 all_names = joined_names + remaining_rvs
677692
678693 all_names = joined_names .copy ()
0 commit comments