@@ -197,10 +197,9 @@ def sort_args(args):
197197 n_seq = len (sequence_names )
198198
199199 def step_fn (* args ):
200- seqs , state , non_seqs = args [:n_seq ], args [n_seq ], args [n_seq + 1 :]
201- non_seqs , rng = non_seqs [:- 1 ], non_seqs [- 1 ]
200+ seqs , (rng , state , * non_seqs ) = args [:n_seq ], args [n_seq :]
202201
203- c , d , T , Z , R , H , Q = sort_args (seqs + non_seqs )
202+ c , d , T , Z , R , H , Q = sort_args (( * seqs , * non_seqs ) )
204203 k = T .shape [0 ]
205204 a = state [:k ]
206205
@@ -219,7 +218,7 @@ def step_fn(*args):
219218
220219 next_state = pt .concatenate ([a_next , y_next ], axis = 0 )
221220
222- return next_state , { rng : next_rng }
221+ return next_rng , next_state
223222
224223 Z_init = Z_ if Z_ in non_sequences else Z_ [0 ]
225224 H_init = H_ if H_ in non_sequences else H_ [0 ]
@@ -229,13 +228,14 @@ def step_fn(*args):
229228
230229 init_dist_ = pt .concatenate ([init_x_ , init_y_ ], axis = 0 )
231230
232- statespace , updates = pytensor .scan (
231+ ss_rng , statespace = pytensor .scan (
233232 step_fn ,
234- outputs_info = [init_dist_ ],
233+ outputs_info = [rng , init_dist_ ],
235234 sequences = None if len (sequences ) == 0 else sequences ,
236- non_sequences = [* non_sequences , rng ],
235+ non_sequences = [* non_sequences ],
237236 n_steps = steps ,
238237 strict = True ,
238+ return_updates = False ,
239239 )
240240
241241 if append_x0 :
@@ -245,7 +245,6 @@ def step_fn(*args):
245245 statespace_ = statespace
246246 statespace_ = pt .specify_shape (statespace_ , (steps , None ))
247247
248- (ss_rng ,) = tuple (updates .values ())
249248 linear_gaussian_ss_op = LinearGaussianStateSpaceRV (
250249 inputs = [a0_ , P0_ , c_ , d_ , T_ , Z_ , R_ , H_ , Q_ , steps , rng ],
251250 outputs = [ss_rng , statespace_ ],
@@ -385,19 +384,22 @@ def rv_op(cls, mus, covs, logp, method="svd", size=None):
385384
386385 def step (mu , cov , rng ):
387386 new_rng , mvn = pm .MvNormal .dist (mu = mu , cov = cov , rng = rng , method = method ).owner .outputs
388- return mvn , { rng : new_rng }
387+ return new_rng , mvn
389388
390- mvn_seq , updates = pytensor .scan (
391- step , sequences = [mus_ , covs_ ], non_sequences = [rng ], strict = True , n_steps = mus_ .shape [0 ]
389+ seq_mvn_rng , mvn_seq = pytensor .scan (
390+ step ,
391+ sequences = [mus_ , covs_ ],
392+ outputs_info = [rng , None ],
393+ strict = True ,
394+ n_steps = mus_ .shape [0 ],
395+ return_updates = False ,
392396 )
393397 mvn_seq = pt .specify_shape (mvn_seq , mus .type .shape )
394398
395399 # Move time axis back to position -2 so batches are on the left
396400 if mvn_seq .ndim > 2 :
397401 mvn_seq = pt .moveaxis (mvn_seq , 0 , - 2 )
398402
399- (seq_mvn_rng ,) = tuple (updates .values ())
400-
401403 mvn_seq_op = KalmanFilterRV (
402404 inputs = [mus_ , covs_ , logp_ , rng ], outputs = [seq_mvn_rng , mvn_seq ], ndim_supp = 2
403405 )
0 commit comments