@@ -35,8 +35,6 @@ def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
3535 def mxm (self , other : Matrix , op : Semiring , mask :Matrix , swap_operands : bool = False ) -> Matrix :
3636 if self .block_matrix_space .is_single_cell (other .shape ):
3737 return self .base .mxm (other , op , mask , swap_operands = swap_operands )
38- if not mask is None :
39- mask = self .block_matrix_space .repeat_into_hyper_column (mask )
4038 return self .base .mxm (
4139 self .block_matrix_space .hyper_rotate (
4240 other ,
@@ -45,7 +43,7 @@ def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = Fa
4543 else BlockMatrixOrientation .HORIZONTAL
4644 ),
4745 op = op ,
48- mask = mask ,
46+ mask = None ,
4947 swap_operands = swap_operands ,
5048 )
5149
@@ -90,22 +88,20 @@ def _force_init_orientation(
9088 return self .matrices [desired_orientation ]
9189
9290 def mxm (self , other : Matrix , op : Semiring , mask :Matrix , swap_operands : bool = False ) -> Matrix :
93- if self .block_matrix_space .is_single_cell (other .shape ):
94- if not mask is None :
95- mask = self .block_matrix_space .hyper_rotate (self .block_matrix_space .repeat_into_hyper_column (mask ),BlockMatrixOrientation .HORIZONTAL )
91+ if self .block_matrix_space .is_single_cell (other .shape ):
9692 return self ._force_init_orientation (
9793 BlockMatrixOrientation .HORIZONTAL
9894 if swap_operands
9995 else BlockMatrixOrientation .VERTICAL
100- ).mxm (other , op , mask , swap_operands = swap_operands )
96+ ).mxm (other , op , None , swap_operands = swap_operands )
10197 return self ._force_init_orientation (
10298 BlockMatrixOrientation .VERTICAL
10399 if swap_operands
104100 else BlockMatrixOrientation .HORIZONTAL
105101 ).mxm (
106102 self .block_matrix_space .to_block_diag_matrix (other ),
107103 op = op ,
108- mask = mask ,
104+ mask = None ,
109105 swap_operands = swap_operands
110106 )
111107
0 commit comments