Skip to content

Commit a102e3c

Browse files
Propagate static shape information in join_tensors_by_dim_labels where possible
1 parent 7581f04 commit a102e3c

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

pymc_extras/statespace/models/utilities.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,16 +618,36 @@ def join_tensors_by_dim_labels(
618618
# Check for no overlap first. In this case, do a block_diagonal join, which implicitly results in padding zeros
619619
# everywhere they are needed -- no other sorting or padding necessary
620620
if combined_labels == [*labels, *other_labels]:
621-
return pt.linalg.block_diag(tensor, other_tensor)
621+
res = pt.linalg.block_diag(tensor, other_tensor)
622+
new_shape = [
623+
shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
624+
for shape_1, shape_2 in zip(tensor.type.shape, other_tensor.type.shape)
625+
]
626+
return pt.specify_shape(res, new_shape)
622627

623628
# Otherwise there is either total overlap or partial overlap. Let the padding and reordering function figure it out.
624629
tensor = ndim_pad_and_reorder(tensor, labels, combined_labels, labeled_axis)
625630
other_tensor = ndim_pad_and_reorder(other_tensor, other_labels, combined_labels, labeled_axis)
626631

627632
if block_diag_join:
628-
return pt.linalg.block_diag(tensor, other_tensor)
633+
new_shape = [
634+
shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
635+
for shape_1, shape_2 in zip(tensor.type.shape, other_tensor.type.shape)
636+
]
637+
res = pt.linalg.block_diag(tensor, other_tensor)
629638
else:
630-
return pt.concatenate([tensor, other_tensor], axis=join_axis)
639+
new_shape = []
640+
join_axis_norm = normalize_axis(tensor, join_axis)
641+
for i, (shape_1, shape_2) in enumerate(zip(tensor.type.shape, other_tensor.type.shape)):
642+
if i == join_axis_norm:
643+
new_shape.append(
644+
shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
645+
)
646+
else:
647+
new_shape.append(shape_1 if shape_1 is not None else shape_2)
648+
res = pt.concatenate([tensor, other_tensor], axis=join_axis)
649+
650+
return pt.specify_shape(res, new_shape)
631651

632652

633653
def get_exog_dims_from_idata(exog_name, idata):

0 commit comments

Comments
 (0)