@@ -375,7 +375,6 @@ def dist(cls, mus, covs, logp, **kwargs):
375
375
def rv_op (cls , mus , covs , logp , size = None ):
376
376
# Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
377
377
mus_ , covs_ = mus .type (), covs .type ()
378
- print (f"mus_.type.shape: { mus_ .type .shape } , covs_.type.shape: { covs_ .type .shape } " )
379
378
380
379
logp_ = logp .type ()
381
380
rng = pytensor .shared (np .random .default_rng ())
@@ -385,7 +384,6 @@ def recursion(mus, covs, rng):
385
384
mus = pt .moveaxis (mus , - 2 , 0 )
386
385
if covs .ndim > 3 :
387
386
covs = pt .moveaxis (covs , - 3 , 0 )
388
- print (f"mus.type.shape: { mus .type .shape } , covs.type.shape: { covs .type .shape } " )
389
387
390
388
def step (mu , cov , rng ):
391
389
new_rng , mvn = pm .MvNormal .dist (mu = mu , cov = cov , rng = rng , method = "svd" ).owner .outputs
@@ -394,32 +392,26 @@ def step(mu, cov, rng):
394
392
mvn_seq , updates = pytensor .scan (
395
393
step , sequences = [mus , covs ], non_sequences = [rng ], strict = True , n_steps = mus .shape [0 ]
396
394
)
397
- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
398
395
mvn_seq = pt .specify_shape (mvn_seq , mus .type .shape )
399
396
400
397
# Move time axis back to position -2 so batches are on the left
401
398
if mvn_seq .ndim > 2 :
402
399
mvn_seq = pt .moveaxis (mvn_seq , 0 , - 2 )
403
- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
404
400
405
401
(seq_mvn_rng ,) = tuple (updates .values ())
406
402
407
- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
408
-
409
403
return [seq_mvn_rng , mvn_seq ]
410
404
411
405
mvn_seq_op = KalmanFilterRV (
412
406
inputs = [mus_ , covs_ , logp_ , rng ], outputs = recursion (mus_ , covs_ , rng ), ndim_supp = 2
413
407
)
414
408
415
409
mvn_seq = mvn_seq_op (mus , covs , logp , rng )
416
- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
417
410
return mvn_seq
418
411
419
412
420
413
@_logprob .register (KalmanFilterRV )
421
414
def sequence_mvnormal_logp (op , values , mus , covs , logp , rng , ** kwargs ):
422
- print (values [0 ].type .shape , mus .type .shape , covs .type .shape )
423
415
return check_parameters (
424
416
logp ,
425
417
pt .eq (values [0 ].shape [- 2 ], mus .shape [- 2 ]),
0 commit comments