Skip to content

Commit f8e7729

Browse files
More static shapes
1 parent c15f965 commit f8e7729

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,13 @@ def make_slice(name, x, o_x):
551551
obs_intercept.name = d.name
552552

553553
transition = pt.linalg.block_diag(T, o_T)
554+
transition = pt.specify_shape(
555+
transition,
556+
shape=[
557+
sum(shapes) if not any([s is None for s in shapes]) else None
558+
for shapes in zip(*[T.type.shape, o_T.type.shape])
559+
],
560+
)
554561
transition.name = T.name
555562

556563
design = join_tensors_by_dim_labels(
@@ -563,6 +570,13 @@ def make_slice(name, x, o_x):
563570
design.name = Z.name
564571

565572
selection = pt.linalg.block_diag(R, o_R)
573+
selection = pt.specify_shape(
574+
selection,
575+
shape=[
576+
sum(shapes) if not any([s is None for s in shapes]) else None
577+
for shapes in zip(*[R.type.shape, o_R.type.shape])
578+
],
579+
)
566580
selection.name = R.name
567581

568582
obs_cov = add_tensors_by_dim_labels(

pymc_extras/statespace/models/utilities.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ def reorder_from_labels(
428428
if indices.tolist() != list(range(n_out)):
429429
for axis in labeled_axis:
430430
idx = np.s_[tuple([slice(None, None) if i != axis else indices for i in range(x.ndim)])]
431-
x = x[idx]
431+
shape = x.type.shape
432+
x = pt.specify_shape(x[idx], shape)
432433

433434
return x
434435

@@ -511,7 +512,11 @@ def ndim_pad_and_reorder(
511512

512513
if n_missing > 0:
513514
pad_size = [(0, 0) if i not in labeled_axis else (0, n_missing) for i in range(x.ndim)]
514-
x = pt.pad(x, pad_size, mode="constant", constant_values=0)
515+
new_shape = [
516+
shape + sum(size) if shape is not None else None
517+
for shape, size in zip(x.type.shape, pad_size)
518+
]
519+
x = pt.specify_shape(pt.pad(x, pad_size, mode="constant", constant_values=0), new_shape)
515520

516521
return reorder_from_labels(x, labels, ordered_labels, labeled_axis)
517522

0 commit comments

Comments
 (0)