@@ -37,30 +37,24 @@ def to_unoptimized(self) -> Matrix:
3737 return self .base
3838
3939 def mxm (self , other : Matrix , op : Semiring , mask : Matrix , swap_operands : bool = False ) -> Matrix :
40- # return (
41- # other.mxm(self.base, op)
42- # if swap_operands
43- # else self.base.mxm(other, op)
44- #).new(self.dtype)
4540 if swap_operands :
4641 if not mask is None :
47- #print("Mask applied, swap operands")
48- #result = Matrix(self.dtype,nrows=other.shape[0],ncols=self.shape[1])
42+ #print("It would be nice to apply mask in swap operands")
43+ print ("Mask applied, swap operands." )
44+ result = Matrix (self .dtype ,nrows = other .shape [0 ],ncols = self .shape [1 ])
4945 #mask_t = Matrix(mask.dtype, ncols=mask.ncols, nrows=mask.nrows)
5046 #mask_t << mask.T
5147 #result(~mask) << other.mxm(self.base, op)
52- # result(~mask) << other.mxm(self.base, op).new(self.dtype)
53- # return result
54- return other .mxm (self .base , op ).new (self .dtype )
48+ result (~ mask ) << other .mxm (self .base , op ).new (self .dtype )
49+ return result
50+ # return other.mxm(self.base, op).new(self.dtype)
5551 else : return other .mxm (self .base , op ).new (self .dtype )
5652 else :
5753 if not mask is None :
58- print ("Mask applied" )
54+ print ("Mask applied. " )
5955 result = Matrix (self .dtype ,nrows = self .shape [0 ],ncols = other .shape [1 ])
60- #result(~mask) << self.base.mxm(other, op)
6156 result (~ mask ) << self .base .mxm (other , op ).new (self .dtype )
6257 return result
63- #return self.base.mxm(other, op).new(self.dtype)
6458 else : return self .base .mxm (other , op ).new (self .dtype )
6559
6660
0 commit comments