11import logging
2+ import warnings
23
34from collections .abc import Callable , Sequence
45from typing import Any , Literal
@@ -98,6 +99,13 @@ class PyMCStateSpace:
9899 compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
99100 from a multivariate normal.
100101
102+ mode: str or Mode, optional
103+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
105+
106+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
107+ to all sampling methods.
108+
101109 Notes
102110 -----
103111 Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -220,6 +228,7 @@ def __init__(
220228 filter_type : str = "standard" ,
221229 verbose : bool = True ,
222230 measurement_error : bool = False ,
231+ mode : str | None = None ,
223232 ):
224233 self ._fit_coords : dict [str , Sequence [str ]] | None = None
225234 self ._fit_dims : dict [str , Sequence [str ]] | None = None
@@ -235,6 +244,7 @@ def __init__(
235244 self .k_states = k_states
236245 self .k_posdef = k_posdef
237246 self .measurement_error = measurement_error
247+ self .mode = mode
238248
239249 # All models contain a state space representation and a Kalman filter
240250 self .ssm = PytensorRepresentation (k_endog , k_states , k_posdef )
@@ -821,6 +831,7 @@ def build_statespace_graph(
821831 cov_jitter : float | None = JITTER_DEFAULT ,
822832 mvn_method : Literal ["cholesky" , "eigh" , "svd" ] = "svd" ,
823833 save_kalman_filter_outputs_in_idata : bool = False ,
834+ mode : str | None = None ,
824835 ) -> None :
825836 """
826837 Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -874,7 +885,25 @@ def build_statespace_graph(
874885 save_kalman_filter_outputs_in_idata: bool, optional, default=False
875886 If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
876887 should not be necessary for the majority of users.
888+
889+ mode: str, optional
890+ Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891+ compiling sampling functions (e.g. ``sample_conditional_prior``).
892+
893+ .. deprecated:: 0.2.5
894+ The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895+ model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896+
877897 """
898+ if mode is not None :
899+ warnings .warn (
900+ "The `mode` argument is deprecated and will be removed in a future version. "
901+ "Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902+ " instead." ,
903+ DeprecationWarning ,
904+ )
905+ self .mode = mode
906+
878907 pm_mod = modelcontext (None )
879908
880909 self ._insert_random_variables ()
@@ -1107,6 +1136,12 @@ def _kalman_filter_outputs_from_dummy_graph(
11071136
11081137 return [x0 , P0 , c , d , T , Z , R , H , Q ], grouped_outputs
11091138
1139+ def _set_default_mode (self , compile_kwargs ):
1140+ mode = compile_kwargs .get ("mode" , self .mode )
1141+ compile_kwargs ["mode" ] = mode
1142+
1143+ return compile_kwargs
1144+
11101145 def _sample_conditional (
11111146 self ,
11121147 idata : InferenceData ,
@@ -1158,6 +1193,9 @@ def _sample_conditional(
11581193 _verify_group (group )
11591194 group_idata = getattr (idata , group )
11601195
1196+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1197+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1198+
11611199 with pm .Model (coords = self ._fit_coords ) as forward_model :
11621200 (
11631201 [
@@ -1224,6 +1262,7 @@ def _sample_conditional(
12241262 for suffix in ["" , "_observed" ]
12251263 ],
12261264 random_seed = random_seed ,
1265+ compile_kwargs = compile_kwargs ,
12271266 ** kwargs ,
12281267 )
12291268
@@ -1289,6 +1328,10 @@ def _sample_unconditional(
12891328 the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
12901329 """
12911330 _verify_group (group )
1331+
1332+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1333+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1334+
12921335 group_idata = getattr (idata , group )
12931336 dims = None
12941337 temp_coords = self ._fit_coords .copy ()
@@ -1347,6 +1390,7 @@ def _sample_unconditional(
13471390 group_idata ,
13481391 var_names = [f"{ group } _latent" , f"{ group } _observed" ],
13491392 random_seed = random_seed ,
1393+ compile_kwargs = compile_kwargs ,
13501394 ** kwargs ,
13511395 )
13521396
@@ -1574,7 +1618,7 @@ def sample_unconditional_posterior(
15741618 )
15751619
15761620 def sample_statespace_matrices (
1577- self , idata , matrix_names : str | list [str ] | None , group : str = "posterior"
1621+ self , idata , matrix_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
15781622 ):
15791623 """
15801624 Draw samples of requested statespace matrices from provided idata
@@ -1591,12 +1635,18 @@ def sample_statespace_matrices(
15911635 group: str, one of "posterior" or "prior"
15921636 Whether to sample from priors or posteriors
15931637
1638+ kwargs:
1639+ Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1640+
15941641 Returns
15951642 -------
15961643 idata_matrices: az.InterenceData
15971644 """
15981645 _verify_group (group )
15991646
1647+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1648+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1649+
16001650 if matrix_names is None :
16011651 matrix_names = MATRIX_NAMES
16021652 elif isinstance (matrix_names , str ):
@@ -1628,6 +1678,8 @@ def sample_statespace_matrices(
16281678 idata if group == "posterior" else idata .prior ,
16291679 var_names = matrix_names ,
16301680 extend_inferencedata = False ,
1681+ compile_kwargs = compile_kwargs ,
1682+ ** kwargs ,
16311683 )
16321684
16331685 return matrix_idata
@@ -2096,6 +2148,10 @@ def forecast(
20962148 filter_time_dim = TIME_DIM
20972149
20982150 _validate_filter_arg (filter_output )
2151+
2152+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2153+ compile_kwargs = self ._set_default_mode (compile_kwargs )
2154+
20992155 time_index = self ._get_fit_time_index ()
21002156
21012157 if start is None and verbose :
@@ -2198,6 +2254,7 @@ def forecast(
21982254 idata ,
21992255 var_names = ["forecast_latent" , "forecast_observed" ],
22002256 random_seed = random_seed ,
2257+ compile_kwargs = compile_kwargs ,
22012258 ** kwargs ,
22022259 )
22032260
@@ -2285,6 +2342,9 @@ def impulse_response_function(
22852342 n_options = sum (x is not None for x in options )
22862343 Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
22872344
2345+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2346+ compile_kwargs = self ._set_default_mode (compile_kwargs )
2347+
22882348 if n_options > 1 :
22892349 raise ValueError ("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory" )
22902350 elif n_options == 1 :
@@ -2364,6 +2424,7 @@ def irf_step(shock, x, c, T, R):
23642424 idata ,
23652425 var_names = ["irf" ],
23662426 random_seed = random_seed ,
2427+ compile_kwargs = compile_kwargs ,
23672428 ** kwargs ,
23682429 )
23692430
0 commit comments