Skip to content

Commit 94eeb7e

Browse files
committed
Debugging pytensor shape issue
1 parent 3f985ee commit 94eeb7e

File tree

2 files changed

+1
-9
lines changed

2 files changed

+1
-9
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ def _insert_random_variables(self):
732732

733733
replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
734734
self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
735-
for name, matrix in zip(LONG_MATRIX_NAMES, self.subbed_ssm):
735+
for name, matrix in zip(MATRIX_NAMES, self.subbed_ssm):
736736
matrix.name = name
737737

738738
def _insert_data_variables(self):

pymc_extras/statespace/filters/distributions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,6 @@ def dist(cls, mus, covs, logp, **kwargs):
375375
def rv_op(cls, mus, covs, logp, size=None):
376376
# Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
377377
mus_, covs_ = mus.type(), covs.type()
378-
print(f"mus_.type.shape: {mus_.type.shape}, covs_.type.shape: {covs_.type.shape}")
379378

380379
logp_ = logp.type()
381380
rng = pytensor.shared(np.random.default_rng())
@@ -385,7 +384,6 @@ def recursion(mus, covs, rng):
385384
mus = pt.moveaxis(mus, -2, 0)
386385
if covs.ndim > 3:
387386
covs = pt.moveaxis(covs, -3, 0)
388-
print(f"mus.type.shape: {mus.type.shape}, covs.type.shape: {covs.type.shape}")
389387

390388
def step(mu, cov, rng):
391389
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
@@ -394,32 +392,26 @@ def step(mu, cov, rng):
394392
mvn_seq, updates = pytensor.scan(
395393
step, sequences=[mus, covs], non_sequences=[rng], strict=True, n_steps=mus.shape[0]
396394
)
397-
print(f"mvn_seq.type.shape: {mvn_seq.type.shape}")
398395
mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
399396

400397
# Move time axis back to position -2 so batches are on the left
401398
if mvn_seq.ndim > 2:
402399
mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
403-
print(f"mvn_seq.type.shape: {mvn_seq.type.shape}")
404400

405401
(seq_mvn_rng,) = tuple(updates.values())
406402

407-
print(f"mvn_seq.type.shape: {mvn_seq.type.shape}")
408-
409403
return [seq_mvn_rng, mvn_seq]
410404

411405
mvn_seq_op = KalmanFilterRV(
412406
inputs=[mus_, covs_, logp_, rng], outputs=recursion(mus_, covs_, rng), ndim_supp=2
413407
)
414408

415409
mvn_seq = mvn_seq_op(mus, covs, logp, rng)
416-
print(f"mvn_seq.type.shape: {mvn_seq.type.shape}")
417410
return mvn_seq
418411

419412

420413
@_logprob.register(KalmanFilterRV)
421414
def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs):
422-
print(values[0].type.shape, mus.type.shape, covs.type.shape)
423415
return check_parameters(
424416
logp,
425417
pt.eq(values[0].shape[-2], mus.shape[-2]),

0 commit comments

Comments
 (0)