Skip to content

Commit 4e41cdd

Browse files
committed
Try to rotate mask. Select lhs matrix with maximal nnz as a mask.
1 parent 0e347a3 commit 4e41cdd

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

cfpq_matrix/block/block_matrix.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
3232
assert block_matrix_space.is_single_cell(base.shape)
3333
super().__init__(base, block_matrix_space)
3434

35-
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
35+
def mxm(self, other: Matrix, op: Semiring, mask: Matrix, swap_operands: bool = False) -> Matrix:
36+
print ("Self shape")
37+
print (self.shape)
38+
print ("Other shape")
39+
print (other.shape)
3640
if self.block_matrix_space.is_single_cell(other.shape):
37-
return self.base.mxm(other, op, mask, swap_operands=swap_operands)
38-
return self.base.mxm(
41+
return self.base.mxm(other, op, mask, swap_operands=swap_operands)
42+
return self.base.mxm(
3943
self.block_matrix_space.hyper_rotate(
4044
other,
4145
BlockMatrixOrientation.VERTICAL
@@ -94,14 +98,19 @@ def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = Fa
9498
if swap_operands
9599
else BlockMatrixOrientation.VERTICAL
96100
).mxm(other, op, None, swap_operands=swap_operands)
101+
mask = (self.block_matrix_space.hyper_rotate(
102+
mask,
103+
BlockMatrixOrientation.VERTICAL
104+
if swap_operands
105+
else BlockMatrixOrientation.HORIZONTAL)) if not mask is None else mask
97106
return self._force_init_orientation(
98107
BlockMatrixOrientation.VERTICAL
99108
if swap_operands
100109
else BlockMatrixOrientation.HORIZONTAL
101110
).mxm(
102111
self.block_matrix_space.to_block_diag_matrix(other),
103112
op=op,
104-
mask=None,
113+
mask=mask,
105114
swap_operands=swap_operands
106115
)
107116

cfpq_model/label_decomposed_graph.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,21 @@ def mxm(
262262
if swap_operands:
263263
rhs1, rhs2 = rhs2, rhs1
264264
if rhs1 in self.matrices and rhs2 in other.matrices:
265-
#if swap_operands:
266-
mask = (accum.matrices[lhs].to_mask()
267-
if lhs in accum.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape
268-
else None)
269-
#else:
270-
# mask = (self.matrices[lhs].to_mask()
271-
# if lhs in self.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape
272-
# else None)
265+
266+
if lhs in accum.matrices:
267+
print("Mask shape")
268+
print(accum.matrices[lhs].shape)
269+
270+
matrix_for_mask = max(((accum.matrices[lhs]
271+
if lhs in accum.matrices
272+
else None),
273+
(self.matrices[lhs]
274+
if lhs in self.matrices
275+
else None),
276+
(other.matrices[lhs]
277+
if lhs in other.matrices
278+
else None)), key = lambda m: m.nvals if not m is None else -1)
279+
mask = matrix_for_mask if isinstance(matrix_for_mask, Matrix) or matrix_for_mask is None else matrix_for_mask.to_mask()
273280

274281
mxm = self.matrices[rhs1].mxm(
275282
other.matrices[rhs2],

0 commit comments

Comments
 (0)