Skip to content

Commit bba8dfe

Browse files
committed
Mask if swap_operands is True.
1 parent f09902e commit bba8dfe

File tree

2 files changed

+17
-20
lines changed

2 files changed

+17
-20
lines changed

cfpq_matrix/matrix_to_optimized_adapter.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,24 @@ def to_unoptimized(self) -> Matrix:
3737
return self.base
3838

3939
def mxm(self, other: Matrix, op: Semiring, mask: Matrix, swap_operands: bool = False) -> Matrix:
40-
# return (
41-
# other.mxm(self.base, op)
42-
# if swap_operands
43-
# else self.base.mxm(other, op)
44-
#).new(self.dtype)
4540
if swap_operands:
4641
if not mask is None:
47-
#print("Mask applied, swap operands")
48-
#result = Matrix(self.dtype,nrows=other.shape[0],ncols=self.shape[1])
42+
#print("It would be nice to apply mask in swap operands")
43+
print("Mask applied, swap operands.")
44+
result = Matrix(self.dtype,nrows=other.shape[0],ncols=self.shape[1])
4945
#mask_t = Matrix(mask.dtype, ncols=mask.ncols, nrows=mask.nrows)
5046
#mask_t << mask.T
5147
#result(~mask) << other.mxm(self.base, op)
52-
#result(~mask) << other.mxm(self.base, op).new(self.dtype)
53-
#return result
54-
return other.mxm(self.base, op).new(self.dtype)
48+
result(~mask) << other.mxm(self.base, op).new(self.dtype)
49+
return result
50+
#return other.mxm(self.base, op).new(self.dtype)
5551
else: return other.mxm(self.base, op).new(self.dtype)
5652
else:
5753
if not mask is None:
58-
print("Mask applied")
54+
print("Mask applied.")
5955
result = Matrix(self.dtype,nrows=self.shape[0],ncols=other.shape[1])
60-
#result(~mask) << self.base.mxm(other, op)
6156
result(~mask) << self.base.mxm(other, op).new(self.dtype)
6257
return result
63-
#return self.base.mxm(other, op).new(self.dtype)
6458
else: return self.base.mxm(other, op).new(self.dtype)
6559

6660

cfpq_model/label_decomposed_graph.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,19 @@ def mxm(
262262
if swap_operands:
263263
rhs1, rhs2 = rhs2, rhs1
264264
if rhs1 in self.matrices and rhs2 in other.matrices:
265-
#if lhs in self.matrices:
266-
# print ("!!! Mask: ")
267-
# print(self.matrices[lhs])
268-
# print("-------------------")
269-
# print(self.matrices[rhs1].shape)
270-
# print(other.matrices[rhs2].shape)
265+
if swap_operands:
266+
mask = (accum.matrices[lhs].to_mask()
267+
if lhs in accum.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape
268+
else None)
269+
else:
270+
mask = (self.matrices[lhs].to_mask()
271+
if lhs in self.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape
272+
else None)
273+
271274
mxm = self.matrices[rhs1].mxm(
272275
other.matrices[rhs2],
273276
swap_operands=swap_operands,
274-
mask=(self.matrices[lhs].to_mask() if lhs in self.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape else None), #Matrix(dtype=self.matrices[rhs1].dtype, nrows=self.matrices[rhs1].nrows, ncols = self.matrices[rhs2].ncols)),
277+
mask=mask,
275278
op=op,
276279
)
277280
accum.iadd_by_symbol(lhs, mxm, op.monoid)

0 commit comments

Comments
 (0)