@@ -618,16 +618,36 @@ def join_tensors_by_dim_labels(
618
618
# Check for no overlap first. In this case, do a block_diagonal join, which implicitly results in padding zeros
619
619
# everywhere they are needed -- no other sorting or padding necessary
620
620
if combined_labels == [* labels , * other_labels ]:
621
- return pt .linalg .block_diag (tensor , other_tensor )
621
+ res = pt .linalg .block_diag (tensor , other_tensor )
622
+ new_shape = [
623
+ shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None ) else None
624
+ for shape_1 , shape_2 in zip (tensor .type .shape , other_tensor .type .shape )
625
+ ]
626
+ return pt .specify_shape (res , new_shape )
622
627
623
628
# Otherwise there is either total overlap or partial overlap. Let the padding and reordering function figure it out.
624
629
tensor = ndim_pad_and_reorder (tensor , labels , combined_labels , labeled_axis )
625
630
other_tensor = ndim_pad_and_reorder (other_tensor , other_labels , combined_labels , labeled_axis )
626
631
627
632
if block_diag_join :
628
- return pt .linalg .block_diag (tensor , other_tensor )
633
+ new_shape = [
634
+ shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None ) else None
635
+ for shape_1 , shape_2 in zip (tensor .type .shape , other_tensor .type .shape )
636
+ ]
637
+ res = pt .linalg .block_diag (tensor , other_tensor )
629
638
else :
630
- return pt .concatenate ([tensor , other_tensor ], axis = join_axis )
639
+ new_shape = []
640
+ join_axis_norm = normalize_axis (tensor , join_axis )
641
+ for i , (shape_1 , shape_2 ) in enumerate (zip (tensor .type .shape , other_tensor .type .shape )):
642
+ if i == join_axis_norm :
643
+ new_shape .append (
644
+ shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None ) else None
645
+ )
646
+ else :
647
+ new_shape .append (shape_1 if shape_1 is not None else shape_2 )
648
+ res = pt .concatenate ([tensor , other_tensor ], axis = join_axis )
649
+
650
+ return pt .specify_shape (res , new_shape )
631
651
632
652
633
653
def get_exog_dims_from_idata (exog_name , idata ):
0 commit comments