Commit 4d7daba

Expose the `method` argument of the MvNormals used in statespace distributions when doing post-estimation tasks
1 parent c8859ed commit 4d7daba

File tree

2 files changed: +72, -15 lines

pymc_extras/statespace/core/statespace.py

Lines changed: 57 additions & 7 deletions
@@ -1109,6 +1109,7 @@ def _sample_conditional(
         group: str,
         random_seed: RandomState | None = None,
         data: pt.TensorLike | None = None,
+        method: str = "svd",
         **kwargs,
     ):
         """
@@ -1130,6 +1131,11 @@ def _sample_conditional(
             Observed data on which to condition the model. If not provided, the function will use the data that was
             provided when the model was built.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1181,6 +1187,7 @@ def _sample_conditional(
                 covs=cov,
                 logp=dummy_ll,
                 dims=state_dims,
+                method=method,
             )

             obs_mu = (Z @ mu[..., None]).squeeze(-1)
@@ -1192,6 +1199,7 @@ def _sample_conditional(
                 covs=obs_cov,
                 logp=dummy_ll,
                 dims=obs_dims,
+                method=method,
             )

             # TODO: Remove this after pm.Flat initial values are fixed
@@ -1222,6 +1230,7 @@ def _sample_unconditional(
         steps: int | None = None,
         use_data_time_dim: bool = False,
         random_seed: RandomState | None = None,
+        method: str = "svd",
         **kwargs,
     ):
         """
@@ -1251,6 +1260,11 @@ def _sample_unconditional(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1309,6 +1323,7 @@ def _sample_unconditional(
                 steps=steps,
                 dims=dims,
                 mode=self._fit_mode,
+                method=method,
                 sequence_names=self.kalman_filter.seq_names,
                 k_endog=self.k_endog,
             )
@@ -1331,7 +1346,7 @@ def _sample_unconditional(
         return idata_unconditional.posterior_predictive

     def sample_conditional_prior(
-        self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
+        self, idata: InferenceData, random_seed: RandomState | None = None, method="svd", **kwargs
     ) -> InferenceData:
         """
         Sample from the conditional prior; that is, given parameter draws from the prior distribution,
@@ -1347,6 +1362,11 @@ def sample_conditional_prior(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1358,10 +1378,10 @@ def sample_conditional_prior(
             "predicted_prior", and "smoothed_prior".
         """

-        return self._sample_conditional(idata, "prior", random_seed, **kwargs)
+        return self._sample_conditional(idata, "prior", random_seed, method, **kwargs)

     def sample_conditional_posterior(
-        self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
+        self, idata: InferenceData, random_seed: RandomState | None = None, method="svd", **kwargs
     ):
         """
         Sample from the conditional posterior; that is, given parameter draws from the posterior distribution,
@@ -1376,6 +1396,11 @@ def sample_conditional_posterior(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1387,14 +1412,15 @@ def sample_conditional_posterior(
             "predicted_posterior", and "smoothed_posterior".
         """

-        return self._sample_conditional(idata, "posterior", random_seed, **kwargs)
+        return self._sample_conditional(idata, "posterior", random_seed, method, **kwargs)

     def sample_unconditional_prior(
         self,
         idata: InferenceData,
         steps: int | None = None,
         use_data_time_dim: bool = False,
         random_seed: RandomState | None = None,
+        method="svd",
         **kwargs,
     ) -> InferenceData:
         """
@@ -1423,6 +1449,11 @@ def sample_unconditional_prior(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -1439,7 +1470,7 @@ def sample_unconditional_prior(
         """

         return self._sample_unconditional(
-            idata, "prior", steps, use_data_time_dim, random_seed, **kwargs
+            idata, "prior", steps, use_data_time_dim, random_seed, method, **kwargs
         )

     def sample_unconditional_posterior(
@@ -1448,6 +1479,7 @@ def sample_unconditional_posterior(
         steps: int | None = None,
         use_data_time_dim: bool = False,
         random_seed: RandomState | None = None,
+        method="svd",
         **kwargs,
     ) -> InferenceData:
         """
@@ -1477,6 +1509,11 @@ def sample_unconditional_posterior(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         Returns
         -------
         InferenceData
@@ -1490,7 +1527,7 @@ def sample_unconditional_posterior(
         """

         return self._sample_unconditional(
-            idata, "posterior", steps, use_data_time_dim, random_seed, **kwargs
+            idata, "posterior", steps, use_data_time_dim, random_seed, method, **kwargs
         )

     def sample_statespace_matrices(
@@ -1933,6 +1970,7 @@ def forecast(
         filter_output="smoothed",
         random_seed: RandomState | None = None,
         verbose: bool = True,
+        method: str = "svd",
         **kwargs,
     ) -> InferenceData:
         """
@@ -1989,6 +2027,11 @@ def forecast(
         verbose: bool, default=True
             Whether to print diagnostic information about forecasting.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -2098,6 +2141,7 @@ def forecast(
                 sequence_names=self.kalman_filter.seq_names,
                 k_endog=self.k_endog,
                 append_x0=False,
+                method=method,
             )

             forecast_model.rvs_to_initial_values = {
@@ -2126,6 +2170,7 @@ def impulse_response_function(
         shock_trajectory: np.ndarray | None = None,
         orthogonalize_shocks: bool = False,
         random_seed: RandomState | None = None,
+        method="svd",
         **kwargs,
     ):
         """
@@ -2177,6 +2222,11 @@ def impulse_response_function(
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.

+        method: str
+            Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
+            fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
+            is "svd".
+
         kwargs:
             Additional keyword arguments are passed to pymc.sample_posterior_predictive

@@ -2236,7 +2286,7 @@ def impulse_response_function(
             shock_trajectory = pt.zeros((n_steps, self.k_posdef))
             if Q is not None:
                 init_shock = pm.MvNormal(
-                    "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
+                    "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method=method
                 )
             else:
                 init_shock = pm.Deterministic(

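All of the post-estimation methods patched above simply forward `method` to the MvNormal draws used internally. As a rough usage sketch, not part of this commit (`ss_mod` and `idata` are assumed to be an already-built statespace model and its fitted trace):

    # Hypothetical calls; signatures follow the diff above.
    # "cholesky" draws fastest from the underlying MvNormals but is the least
    # robust to ill-conditioned covariances; the default "svd" is slowest and
    # most robust.
    cond_post = ss_mod.sample_conditional_posterior(idata, method="cholesky")
    uncond_post = ss_mod.sample_unconditional_posterior(idata, steps=100, method="cholesky")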
pymc_extras/statespace/filters/distributions.py

Lines changed: 15 additions & 8 deletions
@@ -72,6 +72,7 @@ def __new__(
         mode=None,
         sequence_names=None,
         append_x0=True,
+        method="svd",
         **kwargs,
     ):
         # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
@@ -100,6 +101,7 @@ def __new__(
             mode=mode,
             sequence_names=sequence_names,
             append_x0=append_x0,
+            method=method,
             **kwargs,
         )

@@ -119,6 +121,7 @@ def dist(
         mode=None,
         sequence_names=None,
         append_x0=True,
+        method="svd",
         **kwargs,
     ):
         steps = get_support_shape_1d(
@@ -135,6 +138,7 @@ def dist(
             mode=mode,
             sequence_names=sequence_names,
             append_x0=append_x0,
+            method=method,
             **kwargs,
         )

@@ -155,6 +159,7 @@ def rv_op(
         mode=None,
         sequence_names=None,
         append_x0=True,
+        method="svd",
     ):
         if sequence_names is None:
             sequence_names = []
@@ -205,10 +210,10 @@ def step_fn(*args):
             a = state[:k]

             middle_rng, a_innovation = pm.MvNormal.dist(
-                mu=0, cov=Q, rng=rng, method="svd"
+                mu=0, cov=Q, rng=rng, method=method
             ).owner.outputs
             next_rng, y_innovation = pm.MvNormal.dist(
-                mu=0, cov=H, rng=middle_rng, method="svd"
+                mu=0, cov=H, rng=middle_rng, method=method
             ).owner.outputs

             a_mu = c + T @ a
@@ -224,8 +229,8 @@ def step_fn(*args):
         Z_init = Z_ if Z_ in non_sequences else Z_[0]
         H_init = H_ if H_ in non_sequences else H_[0]

-        init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
-        init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
+        init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method)
+        init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method)

         init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

@@ -281,6 +286,7 @@ def __new__(
         sequence_names=None,
         mode=None,
         append_x0=True,
+        method="svd",
         **kwargs,
     ):
         dims = kwargs.pop("dims", None)
@@ -310,6 +316,7 @@ def __new__(
             mode=mode,
             sequence_names=sequence_names,
             append_x0=append_x0,
+            method=method,
             **kwargs,
         )
         latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
@@ -368,11 +375,11 @@ def __new__(cls, *args, **kwargs):
         return super().__new__(cls, *args, **kwargs)

     @classmethod
-    def dist(cls, mus, covs, logp, **kwargs):
-        return super().dist([mus, covs, logp], **kwargs)
+    def dist(cls, mus, covs, logp, method="svd", **kwargs):
+        return super().dist([mus, covs, logp], method=method, **kwargs)

     @classmethod
-    def rv_op(cls, mus, covs, logp, size=None):
+    def rv_op(cls, mus, covs, logp, method="svd", size=None):
         # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
         if mus.ndim > 2:
             mus = pt.moveaxis(mus, -2, 0)
@@ -385,7 +392,7 @@ def rv_op(cls, mus, covs, logp, size=None):
         rng = pytensor.shared(np.random.default_rng())

         def step(mu, cov, rng):
-            new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
+            new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
             return mvn, {rng: new_rng}

         mvn_seq, updates = pytensor.scan(

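The hard-coded `method="svd"` literals above now defer to the caller. For reference, a minimal standalone sketch (not from this repository) of what the kwarg selects at the `pm.MvNormal` level; the covariance matrix below is illustrative only:

    import numpy as np
    import pymc as pm

    # `method` chooses how the covariance is factorized when drawing samples:
    # "cholesky" (fast, assumes a well-conditioned matrix), "eig", or "svd"
    # (slow, tolerates nearly singular matrices).
    cov = np.array([[1.0, 0.999999], [0.999999, 1.0]])  # nearly singular

    svd_draw = pm.draw(pm.MvNormal.dist(mu=np.zeros(2), cov=cov, method="svd"))
    chol_draw = pm.draw(pm.MvNormal.dist(mu=np.zeros(2), cov=cov, method="cholesky"))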