@@ -32,10 +32,14 @@ def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
3232 assert block_matrix_space .is_single_cell (base .shape )
3333 super ().__init__ (base , block_matrix_space )
3434
35- def mxm (self , other : Matrix , op : Semiring , mask :Matrix , swap_operands : bool = False ) -> Matrix :
35+ def mxm (self , other : Matrix , op : Semiring , mask : Matrix , swap_operands : bool = False ) -> Matrix :
36+ print ("Self shape" )
37+ print (self .shape )
38+ print ("Other shape" )
39+ print (other .shape )
3640 if self .block_matrix_space .is_single_cell (other .shape ):
37- return self .base .mxm (other , op , mask , swap_operands = swap_operands )
38- return self .base .mxm (
41+ return self .base .mxm (other , op , mask , swap_operands = swap_operands )
42+ return self .base .mxm (
3943 self .block_matrix_space .hyper_rotate (
4044 other ,
4145 BlockMatrixOrientation .VERTICAL
@@ -94,14 +98,19 @@ def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = Fa
9498 if swap_operands
9599 else BlockMatrixOrientation .VERTICAL
96100 ).mxm (other , op , None , swap_operands = swap_operands )
101+ mask = (self .block_matrix_space .hyper_rotate (
102+ mask ,
103+ BlockMatrixOrientation .VERTICAL
104+ if swap_operands
105+ else BlockMatrixOrientation .HORIZONTAL )) if not mask is None else mask
97106 return self ._force_init_orientation (
98107 BlockMatrixOrientation .VERTICAL
99108 if swap_operands
100109 else BlockMatrixOrientation .HORIZONTAL
101110 ).mxm (
102111 self .block_matrix_space .to_block_diag_matrix (other ),
103112 op = op ,
104- mask = None ,
113+ mask = mask ,
105114 swap_operands = swap_operands
106115 )
107116
0 commit comments