Commit fa470af

Add vectorize_draws argument to sample methods
1 parent 51ad853 commit fa470af

File tree

3 files changed: +111 -30 lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 39 additions & 2 deletions

@@ -960,6 +960,7 @@ def build_statespace_graph(
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
         save_kalman_filter_outputs_in_idata: bool = False,
         mode: str | None = None,
+        vectorize_draws: bool = True,
     ) -> None:
         """
         Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -1022,6 +1023,11 @@ def build_statespace_graph(
            The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
            model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
 
+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         """
         if mode is not None:
             warnings.warn(
@@ -1078,6 +1084,7 @@ def build_statespace_graph(
                 observed=data,
                 dims=obs_dims,
                 method=mvn_method,
+                vectorize_draws=vectorize_draws,
             )
 
             self._fit_coords = pm_mod.coords.copy()
@@ -1271,6 +1278,7 @@ def _sample_conditional(
         random_seed: RandomState | None = None,
         data: pt.TensorLike | None = None,
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
+        vectorize_draws: bool = True,
         **kwargs,
     ):
         """
@@ -1300,6 +1308,11 @@ def _sample_conditional(
            In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
            recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
 
+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive
 
@@ -1355,6 +1368,7 @@ def _sample_conditional(
                 logp=dummy_ll,
                 dims=state_dims,
                 method=mvn_method,
+                vectorize_draws=vectorize_draws,
             )
 
             obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
@@ -1367,6 +1381,7 @@ def _sample_conditional(
                 logp=dummy_ll,
                 dims=obs_dims,
                 method=mvn_method,
+                vectorize_draws=vectorize_draws,
             )
 
             # TODO: Remove this after pm.Flat initial values are fixed
@@ -1523,6 +1538,7 @@ def sample_conditional_prior(
         idata: InferenceData,
         random_seed: RandomState | None = None,
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
+        vectorize_draws: bool = True,
         **kwargs,
     ) -> InferenceData:
         """
@@ -1547,6 +1563,11 @@ def sample_conditional_prior(
            In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
            recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
 
+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive
 
@@ -1559,14 +1580,20 @@ def sample_conditional_prior(
         """
 
         return self._sample_conditional(
-            idata=idata, group="prior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
+            idata=idata,
+            group="prior",
+            random_seed=random_seed,
+            mvn_method=mvn_method,
+            vectorize_draws=vectorize_draws,
+            **kwargs,
         )
 
     def sample_conditional_posterior(
         self,
         idata: InferenceData,
         random_seed: RandomState | None = None,
         mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
+        vectorize_draws: bool = True,
         **kwargs,
     ):
         """
@@ -1590,6 +1617,11 @@ def sample_conditional_posterior(
            In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
            recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
 
+        vectorize_draws : bool, default True
+            If True, sample all draws in a single vectorized operation. This is significantly faster but requires
+            more memory. It is strongly recommended to keep this True unless the state space is so large that memory
+            becomes an issue.
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive
 
@@ -1602,7 +1634,12 @@ def sample_conditional_posterior(
         """
 
         return self._sample_conditional(
-            idata=idata, group="posterior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
+            idata=idata,
+            group="posterior",
+            random_seed=random_seed,
+            mvn_method=mvn_method,
+            vectorize_draws=vectorize_draws,
+            **kwargs,
         )
 
     def sample_unconditional_prior(
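
For orientation, here is a sketch of the new flag from the user's side. Everything around the two new keyword arguments is assumed context (a built statespace model ss_mod, observed data df, and priors matching ss_mod.param_names), not part of this commit:

import pymc as pm

# Assumed: ss_mod is a built pymc-extras statespace model and df holds its data.
with pm.Model(coords=ss_mod.coords) as model:
    ...  # priors for ss_mod.param_names go here
    ss_mod.build_statespace_graph(df, vectorize_draws=True)  # new argument
    idata = pm.sample()

# The conditional sampling methods accept the same flag; pass False to fall
# back to the sequential scan when the vectorized draw needs too much memory.
cond_post = ss_mod.sample_conditional_posterior(idata, vectorize_draws=False)
cond_prior = ss_mod.sample_conditional_prior(idata, vectorize_draws=False)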

pymc_extras/statespace/filters/distributions.py

Lines changed: 42 additions & 28 deletions

@@ -9,6 +9,7 @@
 from pymc.distributions.shape_utils import get_support_shape_1d
 from pymc.logprob.abstract import _logprob
 from pytensor.graph.basic import Node
+from pytensor.tensor.random import multivariate_normal
 
 floatX = pytensor.config.floatX
 COV_ZERO_TOL = 0
@@ -366,45 +367,58 @@ def __new__(cls, *args, **kwargs):
         return super().__new__(cls, *args, **kwargs)
 
     @classmethod
-    def dist(cls, mus, covs, logp, method="svd", **kwargs):
-        return super().dist([mus, covs, logp], method=method, **kwargs)
+    def dist(cls, mus, covs, logp, method="svd", vectorize_draws=True, **kwargs):
+        mus, covs, logp = map(pt.as_tensor_variable, (mus, covs, logp))
+        return super().dist(
+            [mus, covs, logp], method=method, vectorize_draws=vectorize_draws, **kwargs
+        )
 
     @classmethod
-    def rv_op(cls, mus, covs, logp, method="svd", size=None):
-        # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
-        if mus.ndim > 2:
-            mus = pt.moveaxis(mus, -2, 0)
-        if covs.ndim > 3:
-            covs = pt.moveaxis(covs, -3, 0)
-
-        mus_, covs_ = mus.type(), covs.type()
-
-        logp_ = logp.type()
+    def rv_op(cls, mus, covs, logp, method="svd", vectorize_draws=True, size=None):
         rng = pytensor.shared(np.random.default_rng())
+        logp_ = logp.type()
 
-        def step(mu, cov, rng):
-            new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
-            return new_rng, mvn
-
-        seq_mvn_rng, mvn_seq = pytensor.scan(
-            step,
-            sequences=[mus_, covs_],
-            outputs_info=[rng, None],
-            strict=True,
-            n_steps=mus_.shape[0],
-            return_updates=False,
-        )
-        mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
+        if vectorize_draws:
+            mus_, covs_ = mus.type(), covs.type()
+            seq_mvn_rng, mvn_seq = multivariate_normal(
+                mean=mus_, cov=covs_, rng=rng, method=method
+            ).owner.outputs
 
-        # Move time axis back to position -2 so batches are on the left
-        if mvn_seq.ndim > 2:
-            mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
+        else:
+            # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
+            if mus.ndim > 2:
+                mus = pt.moveaxis(mus, -2, 0)
+            if covs.ndim > 3:
+                covs = pt.moveaxis(covs, -3, 0)
+
+            mus_, covs_ = mus.type(), covs.type()
+
+            def step(mu, cov, rng):
+                new_rng, mvn = pm.MvNormal.dist(
+                    mu=mu, cov=cov, rng=rng, method=method
+                ).owner.outputs
+                return new_rng, mvn
+
+            seq_mvn_rng, mvn_seq = pytensor.scan(
+                step,
+                sequences=[mus_, covs_],
+                outputs_info=[rng, None],
+                strict=True,
+                n_steps=mus_.shape[0],
+                return_updates=False,
+            )
+            mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
+
+            # Move time axis back to position -2 so batches are on the left
+            if mvn_seq.ndim > 2:
+                mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
 
         mvn_seq_op = KalmanFilterRV(
             inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
         )
 
         mvn_seq = mvn_seq_op(mus, covs, logp, rng)
+
         return mvn_seq
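
The substance of the change is in rv_op: it previously always looped over the time dimension with pytensor.scan, drawing one MvNormal per step, while with vectorize_draws=True it now makes a single batched multivariate_normal call over the whole (time, state, state) covariance stack. A standalone NumPy analogue of the two strategies (illustrative only, not the library code):

import numpy as np

rng = np.random.default_rng(0)
n_time, k = 4, 3

mus = rng.normal(size=(n_time, k))
A = rng.normal(size=(n_time, k, k))
covs = A @ A.transpose(0, 2, 1) + 0.1 * np.eye(k)  # one SPD covariance per step

# Sequential strategy (the scan path): one small draw per time step.
loop_draws = np.stack([rng.multivariate_normal(m, c) for m, c in zip(mus, covs)])

# Vectorized strategy (the new path): one batched Cholesky factorization and
# one batched matmul produce all draws at once, holding the full
# (n_time, k, k) factor in memory; hence the speed/memory trade-off the
# docstring warns about.
L = np.linalg.cholesky(covs)         # (n_time, k, k)
z = rng.normal(size=(n_time, k, 1))
vec_draws = mus + (L @ z)[..., 0]    # (n_time, k)

assert loop_draws.shape == vec_draws.shape == (n_time, k)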

tests/statespace/filters/test_distributions.py

Lines changed: 30 additions & 0 deletions

@@ -5,6 +5,7 @@
 import pytest
 
 from numpy.testing import assert_allclose
+from pytensor.graph.basic import equal_computations
 from scipy.stats import multivariate_normal
 
 from pymc_extras.statespace import structural
@@ -268,3 +269,32 @@ def test_lgss_signature():
     )
     assert lgss.owner.op.ndim_supp == 2
     assert lgss.owner.op.ndims_params == [1, 2, 1, 1, 3, 2, 2, 2, 2]
+
+
+def test_sequence_mvnormal_vectorize_draws(rng):
+    n_time = 50
+    k_states = 3
+
+    mu_shape = (n_time, k_states)
+    cov_shape = (n_time, k_states, k_states)
+    logp_shape = (n_time,)
+
+    mus = rng.random(size=mu_shape).astype(floatX)
+    covs = np.zeros(cov_shape, dtype=floatX)
+    for idx in np.ndindex(cov_shape[:-2]):
+        A = rng.random(size=(k_states, k_states)).astype(floatX)
+        covs[idx] = A @ A.T + 0.1 * np.eye(k_states)
+    logp = rng.random(size=logp_shape).astype(floatX)
+
+    seed = sum(map(ord, "test_sequence_mvnormal_vectorize_draws"))
+    with pm.Model() as m1:
+        x_vectorized = SequenceMvNormal("x", mus=mus, covs=covs, logp=logp, vectorize_draws=True)
+        idata1 = pm.sample_prior_predictive(draws=5, random_seed=seed)
+
+    with pm.Model() as m2:
+        x_sequential = SequenceMvNormal("x", mus=mus, covs=covs, logp=logp, vectorize_draws=False)
+        idata2 = pm.sample_prior_predictive(draws=5, random_seed=seed)
+
+    assert_allclose(idata1.prior["x"].values, idata2.prior["x"].values, atol=ATOL, rtol=RTOL)
+
+    assert not equal_computations([x_vectorized], [x_sequential])
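
The test seeds both code paths identically, checks that the draws agree numerically, and uses equal_computations to confirm the two graphs genuinely differ, i.e. that the flag changes the sampling graph rather than being silently ignored. It relies on the module's existing rng fixture and its floatX, ATOL, and RTOL constants.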
