@@ -618,16 +618,36 @@ def join_tensors_by_dim_labels(
     # Check for no overlap first. In this case, do a block_diagonal join, which implicitly results in padding zeros
     # everywhere they are needed -- no other sorting or padding necessary
     if combined_labels == [*labels, *other_labels]:
-        return pt.linalg.block_diag(tensor, other_tensor)
+        res = pt.linalg.block_diag(tensor, other_tensor)
+        new_shape = [
+            shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
+            for shape_1, shape_2 in zip(tensor.type.shape, other_tensor.type.shape)
+        ]
+        return pt.specify_shape(res, new_shape)
 
     # Otherwise there is either total overlap or partial overlap. Let the padding and reordering function figure it out.
     tensor = ndim_pad_and_reorder(tensor, labels, combined_labels, labeled_axis)
     other_tensor = ndim_pad_and_reorder(other_tensor, other_labels, combined_labels, labeled_axis)
 
     if block_diag_join:
-        return pt.linalg.block_diag(tensor, other_tensor)
+        new_shape = [
+            shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
+            for shape_1, shape_2 in zip(tensor.type.shape, other_tensor.type.shape)
+        ]
+        res = pt.linalg.block_diag(tensor, other_tensor)
     else:
-        return pt.concatenate([tensor, other_tensor], axis=join_axis)
+        new_shape = []
+        join_axis_norm = normalize_axis(tensor, join_axis)
+        for i, (shape_1, shape_2) in enumerate(zip(tensor.type.shape, other_tensor.type.shape)):
+            if i == join_axis_norm:
+                new_shape.append(
+                    shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
+                )
+            else:
+                new_shape.append(shape_1 if shape_1 is not None else shape_2)
+        res = pt.concatenate([tensor, other_tensor], axis=join_axis)
+
+    return pt.specify_shape(res, new_shape)
 
 
 def get_exog_dims_from_idata(exog_name, idata):
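For context, here is a minimal sketch (not part of the commit) of the pattern the diff applies: recompute the joined static shape from the inputs' .type.shape and pin it back onto the result with pt.specify_shape. The variable names and example shapes below are illustrative assumptions, not taken from the repository.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4, 5))

# Block-diagonal join: each output dimension is the sum of the two input
# dimensions, but only when both are statically known; otherwise it stays None.
res = pt.linalg.block_diag(x, y)
new_shape = [
    s1 + s2 if (s1 is not None and s2 is not None) else None
    for s1, s2 in zip(x.type.shape, y.type.shape)
]
res = pt.specify_shape(res, new_shape)  # static shape is now (6, 8)

# Concatenate join: only the join axis grows; the other axes keep whichever
# static size is available from either input.
z = pt.tensor("z", shape=(7, 3))
res2 = pt.concatenate([x, z], axis=0)
res2 = pt.specify_shape(res2, (2 + 7, 3))  # static shape is now (9, 3)

Pinning the shape this way keeps concrete dimensions available to downstream graph rewrites and shape checks even when block_diag or concatenate alone does not propagate them.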