@@ -960,6 +960,7 @@ def build_statespace_graph(
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
         save_kalman_filter_outputs_in_idata: bool = False,
         mode: str | None = None,
+        vectorize_draws: bool = True,
     ) -> None:
         """
         Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -1022,6 +1023,11 @@ def build_statespace_graph(
             The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
             model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.

+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         """
         if mode is not None:
             warnings.warn(
@@ -1078,6 +1084,7 @@ def build_statespace_graph(
             observed=data,
             dims=obs_dims,
             method=mvn_method,
+            vectorize_draws=vectorize_draws,
         )

         self._fit_coords = pm_mod.coords.copy()
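
For intuition on the trade-off the new flag controls, here is a small, self-contained NumPy illustration of vectorized versus per-draw sampling. It is not the library's implementation (the actual draws go through `SequenceMvNormal`, as in the hunk above); it only shows why a single batched operation is faster while needing memory proportional to the number of draws.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_timesteps, k_states = 500, 200, 10

# Per-timestep means and a shared covariance, standing in for Kalman filter outputs.
mus = rng.normal(size=(n_timesteps, k_states))
chol = np.linalg.cholesky(np.eye(k_states))

# "vectorize_draws=True"-style: one batched operation over every draw at once.
# The working array scales with n_draws * n_timesteps * k_states.
z = rng.standard_normal(size=(n_draws, n_timesteps, k_states))
vectorized = mus[None, :, :] + z @ chol.T

# "vectorize_draws=False"-style: loop over draws, materializing one draw at a time.
# Slower in Python, but only one (n_timesteps, k_states) block is active per step.
looped = np.empty_like(vectorized)
for i in range(n_draws):
    z_i = rng.standard_normal(size=(n_timesteps, k_states))
    looped[i] = mus + z_i @ chol.T
```
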
@@ -1271,6 +1278,7 @@ def _sample_conditional(
         random_seed: RandomState | None = None,
         data: pt.TensorLike | None = None,
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
+        vectorize_draws: bool = True,
         **kwargs,
     ):
         """
@@ -1300,6 +1308,11 @@ def _sample_conditional(
             In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
             recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".

+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1355,6 +1368,7 @@ def _sample_conditional(
                 logp=dummy_ll,
                 dims=state_dims,
                 method=mvn_method,
+                vectorize_draws=vectorize_draws,
             )

             obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
@@ -1367,6 +1381,7 @@ def _sample_conditional(
                 logp=dummy_ll,
                 dims=obs_dims,
                 method=mvn_method,
+                vectorize_draws=vectorize_draws,
             )

         # TODO: Remove this after pm.Flat initial values are fixed
@@ -1523,6 +1538,7 @@ def sample_conditional_prior(
         idata: InferenceData,
         random_seed: RandomState | None = None,
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
+        vectorize_draws: bool = True,
         **kwargs,
     ) -> InferenceData:
         """
@@ -1547,6 +1563,11 @@ def sample_conditional_prior(
             In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
             recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".

+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1559,14 +1580,20 @@ def sample_conditional_prior(
15591580 """
15601581
15611582 return self ._sample_conditional (
1562- idata = idata , group = "prior" , random_seed = random_seed , mvn_method = mvn_method , ** kwargs
1583+ idata = idata ,
1584+ group = "prior" ,
1585+ random_seed = random_seed ,
1586+ mvn_method = mvn_method ,
1587+ vectorize_draws = vectorize_draws ,
1588+ ** kwargs ,
15631589 )
15641590
15651591 def sample_conditional_posterior (
15661592 self ,
15671593 idata : InferenceData ,
15681594 random_seed : RandomState | None = None ,
15691595 mvn_method : Literal ["cholesky" , "eigh" , "svd" ] = "svd" ,
1596+ vectorize_draws : bool = True ,
15701597 ** kwargs ,
15711598 ):
15721599 """
@@ -1590,6 +1617,11 @@ def sample_conditional_posterior(
             In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
             recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".

+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1602,7 +1634,12 @@ def sample_conditional_posterior(
16021634 """
16031635
16041636 return self ._sample_conditional (
1605- idata = idata , group = "posterior" , random_seed = random_seed , mvn_method = mvn_method , ** kwargs
1637+ idata = idata ,
1638+ group = "posterior" ,
1639+ random_seed = random_seed ,
1640+ mvn_method = mvn_method ,
1641+ vectorize_draws = vectorize_draws ,
1642+ ** kwargs ,
16061643 )
16071644
16081645 def sample_unconditional_prior (
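
Finally, a hedged usage sketch of the new keyword on the public conditional sampling methods touched by this diff. Here `ss_mod` and `idata` are placeholders for an already-built statespace model and its fitted `InferenceData`; neither is defined anywhere in this change.

```python
# Default behaviour: all draws sampled in one vectorized operation (fast, memory-heavy).
cond_post = ss_mod.sample_conditional_posterior(idata, random_seed=1)

# If the state space is large enough that memory becomes the bottleneck,
# trade speed for a smaller footprint.
cond_post_low_mem = ss_mod.sample_conditional_posterior(
    idata,
    mvn_method="svd",
    vectorize_draws=False,
    random_seed=1,
)
```
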