Skip to content

Commit ceb7ff3

Browse files
committed
Switch to new Scan API
1 parent 0057a81 commit ceb7ff3

File tree

9 files changed

+65
-54
lines changed

9 files changed

+65
-54
lines changed

notebooks/DFM_Example_(Coincident_Index).ipynb

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -930,19 +930,19 @@
930930
"</pre>\n"
931931
],
932932
"text/plain": [
933-
"\u001b[3m Model Requirements \u001b[0m\n",
933+
"\u001B[3m Model Requirements \u001B[0m\n",
934934
" \n",
935-
" \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n",
935+
" \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1m Dimensions\u001B[0m\u001B[1m \u001B[0m \n",
936936
" ────────────────────────────────────────────────────────────────────────────────────────── \n",
937-
" x0 \u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n",
938-
" P0 \u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m, \u001b[1;36m10\u001b[0m\u001b[1m)\u001b[0m Positive Semi-definite \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n",
939-
" factor_loadings \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m, \u001b[32m'factor'\u001b[0m\u001b[1m)\u001b[0m \n",
940-
" factor_ar \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'factor'\u001b[0m, \u001b[32m'lag_ar'\u001b[0m\u001b[1m)\u001b[0m \n",
941-
" error_ar \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m, \u001b[32m'error_lag_ar'\u001b[0m\u001b[1m)\u001b[0m \n",
942-
" error_sigma \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m,\u001b[1m)\u001b[0m Positive \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
937+
" x0 \u001B[1m(\u001B[0m\u001B[1;36m10\u001B[0m,\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'state'\u001B[0m,\u001B[1m)\u001B[0m \n",
938+
" P0 \u001B[1m(\u001B[0m\u001B[1;36m10\u001B[0m, \u001B[1;36m10\u001B[0m\u001B[1m)\u001B[0m Positive Semi-definite \u001B[1m(\u001B[0m\u001B[32m'state'\u001B[0m, \u001B[32m'state_aux'\u001B[0m\u001B[1m)\u001B[0m \n",
939+
" factor_loadings \u001B[1m(\u001B[0m\u001B[1;36m4\u001B[0m, \u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m, \u001B[32m'factor'\u001B[0m\u001B[1m)\u001B[0m \n",
940+
" factor_ar \u001B[1m(\u001B[0m\u001B[1;36m1\u001B[0m, \u001B[1;36m2\u001B[0m\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'factor'\u001B[0m, \u001B[32m'lag_ar'\u001B[0m\u001B[1m)\u001B[0m \n",
941+
" error_ar \u001B[1m(\u001B[0m\u001B[1;36m4\u001B[0m, \u001B[1;36m2\u001B[0m\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m, \u001B[32m'error_lag_ar'\u001B[0m\u001B[1m)\u001B[0m \n",
942+
" error_sigma \u001B[1m(\u001B[0m\u001B[1;36m4\u001B[0m,\u001B[1m)\u001B[0m Positive \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
943943
" \n",
944-
"\u001b[2;3m These parameters should be assigned priors inside a PyMC model block before calling the \u001b[0m\n",
945-
"\u001b[2;3m build_statespace_graph method. \u001b[0m\n"
944+
"\u001B[2;3m These parameters should be assigned priors inside a PyMC model block before calling the \u001B[0m\n",
945+
"\u001B[2;3m build_statespace_graph method. \u001B[0m\n"
946946
]
947947
},
948948
"metadata": {},
@@ -1759,8 +1759,11 @@
17591759
" K = pt.linalg.solve(F, PZT.T, assume_a=\"pos\", check_finite=False).T\n",
17601760
" return K\n",
17611761
"\n",
1762-
" ss_kalman_gain, updates = pytensor.scan(\n",
1763-
" step, non_sequences=[Z, T, H], sequences=[predicted_covariance]\n",
1762+
" ss_kalman_gain = pytensor.scan(\n",
1763+
" step,\n",
1764+
" non_sequences=[Z, T, H],\n",
1765+
" sequences=[predicted_covariance],\n",
1766+
" return_updates=False,\n",
17641767
" )\n",
17651768
" # Get the last Kalman gain (steady state)\n",
17661769
" ss_kalman_gain = ss_kalman_gain[-1]\n",

notebooks/discrete_markov_chain.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,13 +558,14 @@
558558
"\n",
559559
" return y_out\n",
560560
"\n",
561-
" result, updates = pytensor.scan(\n",
561+
" result = pytensor.scan(\n",
562562
" AR_step,\n",
563563
" sequences=[\n",
564564
" {\"input\": hidden_states, \"taps\": [0, -1, -2, -3, -4]},\n",
565565
" {\"input\": y, \"taps\": [-1, -2, -3, -4]},\n",
566566
" ],\n",
567567
" non_sequences=[state_mus, ar_coefs],\n",
568+
" return_updates=False,\n",
568569
" )\n",
569570
"\n",
570571
" sigma = pm.HalfCauchy(\"sigma\", 0.8)\n",

pymc_extras/distributions/timeseries.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,21 +196,20 @@ def rv_op(cls, P, steps, init_dist, n_lags, size=None):
196196
state_rng = pytensor.shared(np.random.default_rng())
197197

198198
def transition(*args):
199-
*states, transition_probs, old_rng = args
199+
old_rng, *states, transition_probs = args
200200
p = transition_probs[tuple(states)]
201201
next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
202-
return next_state, {old_rng: next_rng}
202+
return next_rng, next_state
203203

204-
markov_chain, state_updates = pytensor.scan(
204+
state_next_rng, markov_chain = pytensor.scan(
205205
transition,
206-
non_sequences=[P_, state_rng],
207-
outputs_info=_make_outputs_info(n_lags, init_dist_),
206+
outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
207+
non_sequences=[P_],
208208
n_steps=steps_,
209209
strict=True,
210+
return_updates=False,
210211
)
211212

212-
(state_next_rng,) = tuple(state_updates.values())
213-
214213
discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
215214

216215
discrete_mc_op = DiscreteMarkovChainRV(
@@ -243,12 +242,13 @@ def greedy_transition(*args):
243242
p = transition_probs[tuple(states)]
244243
return pt.argmax(p)
245244

246-
chain_moment, moment_updates = pytensor.scan(
245+
chain_moment = pytensor.scan(
247246
greedy_transition,
248247
non_sequences=[P, state_rng],
249248
outputs_info=_make_outputs_info(n_lags, init_dist),
250249
n_steps=steps,
251250
strict=True,
251+
return_updates=False,
252252
)
253253
chain_moment = pt.concatenate([init_dist_moment, chain_moment])
254254
return chain_moment

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,13 @@ def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
278278
z = pt.diff(g, axis=0)
279279
alpha_l_init = pt.ones(N)
280280

281-
alpha, _ = pytensor.scan(
281+
alpha = pytensor.scan(
282282
fn=compute_alpha_l,
283283
outputs_info=alpha_l_init,
284284
sequences=[s, z],
285285
n_steps=Lp1 - 1,
286286
allow_gc=False,
287+
return_updates=False,
287288
)
288289

289290
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
@@ -334,11 +335,12 @@ def chi_update(diff_l, chi_lm1) -> TensorVariable:
334335
return pt.set_subtensor(chi_l[j_last], diff_l)
335336

336337
chi_init = pt.zeros((J, N))
337-
chi_mat, _ = pytensor.scan(
338+
chi_mat = pytensor.scan(
338339
fn=chi_update,
339340
outputs_info=chi_init,
340341
sequences=[diff],
341342
allow_gc=False,
343+
return_updates=False,
342344
)
343345

344346
chi_mat = pt.matrix_transpose(chi_mat)
@@ -377,14 +379,14 @@ def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
377379
eta = pt.diagonal(E, axis1=-2, axis2=-1)
378380

379381
# beta: (L, N, 2J)
380-
alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
382+
alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
381383
beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)
382384

383385
# more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
384386

385387
# E_inv: (L, J, J)
386388
E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
387-
eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
389+
eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)
388390

389391
# block_dd: (L, J, J)
390392
block_dd = (
@@ -530,7 +532,9 @@ def bfgs_sample_sparse(
530532

531533
# qr_input: (L, N, 2J)
532534
qr_input = inv_sqrt_alpha_diag @ beta
533-
(Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)
535+
Q, R = pytensor.scan(
536+
fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
537+
)
534538

535539
IdN = pt.eye(R.shape[1])[None, ...]
536540
IdN += IdN * REGULARISATION_TERM
@@ -623,10 +627,11 @@ def bfgs_sample(
623627

624628
L, N, JJ = beta.shape
625629

626-
(alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag), _ = pytensor.scan(
630+
alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
627631
lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
628632
sequences=[alpha],
629633
allow_gc=False,
634+
return_updates=False,
630635
)
631636

632637
u = pt.random.normal(size=(L, num_samples, N))

pymc_extras/model/marginal/distributions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
282282
def logp_fn(marginalized_rv_const, *non_sequences):
283283
return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
284284

285-
joint_logps, _ = scan_map(
285+
joint_logps = scan_map(
286286
fn=logp_fn,
287287
sequences=marginalized_rv_domain_tensor,
288288
non_sequences=[*values, *inputs],
289289
mode=Mode().including("local_remove_check_parameter"),
290+
return_updates=False,
290291
)
291292

292293
joint_logp = pt.logsumexp(joint_logps, axis=0)
@@ -350,12 +351,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
350351

351352
P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
352353
log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
353-
log_alpha_seq, _ = scan(
354+
log_alpha_seq = scan(
354355
step_alpha,
355356
non_sequences=[log_P],
356357
outputs_info=[log_alpha_init],
357358
# Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
358359
sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
360+
return_updates=False,
359361
)
360362
# Final logp is just the sum of the last scan state
361363
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

pymc_extras/statespace/core/statespace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2500,13 +2500,14 @@ def irf_step(shock, x, c, T, R):
25002500
next_x = c + T @ x + R @ shock
25012501
return next_x
25022502

2503-
irf, updates = pytensor.scan(
2503+
irf = pytensor.scan(
25042504
irf_step,
25052505
sequences=[shock_trajectory],
25062506
outputs_info=[x0],
25072507
non_sequences=[c, T, R],
25082508
n_steps=n_steps,
25092509
strict=True,
2510+
return_updates=False,
25102511
)
25112512

25122513
pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])

pymc_extras/statespace/filters/distributions.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,9 @@ def sort_args(args):
197197
n_seq = len(sequence_names)
198198

199199
def step_fn(*args):
200-
seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
201-
non_seqs, rng = non_seqs[:-1], non_seqs[-1]
200+
seqs, (rng, state, *non_seqs) = args[:n_seq], args[n_seq:]
202201

203-
c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
202+
c, d, T, Z, R, H, Q = sort_args((*seqs, *non_seqs))
204203
k = T.shape[0]
205204
a = state[:k]
206205

@@ -219,7 +218,7 @@ def step_fn(*args):
219218

220219
next_state = pt.concatenate([a_next, y_next], axis=0)
221220

222-
return next_state, {rng: next_rng}
221+
return next_rng, next_state
223222

224223
Z_init = Z_ if Z_ in non_sequences else Z_[0]
225224
H_init = H_ if H_ in non_sequences else H_[0]
@@ -229,13 +228,14 @@ def step_fn(*args):
229228

230229
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
231230

232-
statespace, updates = pytensor.scan(
231+
ss_rng, statespace = pytensor.scan(
233232
step_fn,
234-
outputs_info=[init_dist_],
233+
outputs_info=[rng, init_dist_],
235234
sequences=None if len(sequences) == 0 else sequences,
236-
non_sequences=[*non_sequences, rng],
235+
non_sequences=[*non_sequences],
237236
n_steps=steps,
238237
strict=True,
238+
return_updates=False,
239239
)
240240

241241
if append_x0:
@@ -245,7 +245,6 @@ def step_fn(*args):
245245
statespace_ = statespace
246246
statespace_ = pt.specify_shape(statespace_, (steps, None))
247247

248-
(ss_rng,) = tuple(updates.values())
249248
linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
250249
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
251250
outputs=[ss_rng, statespace_],
@@ -385,19 +384,22 @@ def rv_op(cls, mus, covs, logp, method="svd", size=None):
385384

386385
def step(mu, cov, rng):
387386
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
388-
return mvn, {rng: new_rng}
387+
return new_rng, mvn
389388

390-
mvn_seq, updates = pytensor.scan(
391-
step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
389+
seq_mvn_rng, mvn_seq = pytensor.scan(
390+
step,
391+
sequences=[mus_, covs_],
392+
outputs_info=[rng, None],
393+
strict=True,
394+
n_steps=mus_.shape[0],
395+
return_updates=False,
392396
)
393397
mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
394398

395399
# Move time axis back to position -2 so batches are on the left
396400
if mvn_seq.ndim > 2:
397401
mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
398402

399-
(seq_mvn_rng,) = tuple(updates.values())
400-
401403
mvn_seq_op = KalmanFilterRV(
402404
inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
403405
)

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,9 @@ def build_graph(
148148
R,
149149
H,
150150
Q,
151-
return_updates=False,
152151
missing_fill_value=None,
153152
cov_jitter=None,
154-
) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
153+
) -> list[TensorVariable]:
155154
"""
156155
Construct the computation graph for the Kalman filter. See [1] for details.
157156
@@ -211,20 +210,17 @@ def build_graph(
211210
if len(sequences) > 0:
212211
sequences = self.add_check_on_time_varying_shapes(data, sequences)
213212

214-
results, updates = pytensor.scan(
213+
results = pytensor.scan(
215214
self.kalman_step,
216215
sequences=[data, *sequences],
217216
outputs_info=[None, a0, None, None, P0, None, None],
218217
non_sequences=non_sequences,
219218
name="forward_kalman_pass",
220219
strict=False,
220+
return_updates=False,
221221
)
222222

223-
filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
224-
225-
if return_updates:
226-
return filter_results, updates
227-
return filter_results
223+
return self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
228224

229225
def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
230226
"""
@@ -786,11 +782,12 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
786782
H_masked = W.dot(H)
787783
y_masked = pt.set_subtensor(y[nan_mask], 0.0)
788784

789-
result, updates = pytensor.scan(
785+
result = pytensor.scan(
790786
self._univariate_inner_filter_step,
791787
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
792788
outputs_info=[a, P, None, None, None],
793789
name="univariate_inner_scan",
790+
return_updates=False,
794791
)
795792

796793
a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,16 @@ def build_graph(
7676
self.seq_names = seq_names
7777
self.non_seq_names = non_seq_names
7878

79-
smoother_result, updates = pytensor.scan(
79+
smoothed_states, smoothed_covariances = pytensor.scan(
8080
self.smoother_step,
8181
sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
8282
outputs_info=[a_last, P_last],
8383
non_sequences=non_sequences,
8484
go_backwards=True,
8585
name="kalman_smoother",
86+
return_updates=False,
8687
)
8788

88-
smoothed_states, smoothed_covariances = smoother_result
8989
smoothed_states = pt.concatenate(
9090
[smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
9191
)

0 commit comments

Comments (0)