Skip to content

Commit 9fbc3ae

Browse files
dance858Transurgeon
authored andcommitted
add some tests and sketch new converter
1 parent 6301f44 commit 9fbc3ae

File tree

2 files changed

+79
-23
lines changed

2 files changed

+79
-23
lines changed

cvxpy/reductions/solvers/nlp_solvers/diff_engine/converters.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,31 +48,47 @@ def _convert_matmul(expr, children):
4848
if left_arg.is_constant():
4949
A = left_arg.value
5050

51-
if not isinstance(A, sparse.csr_matrix):
52-
A = sparse.csr_matrix(A)
53-
54-
return _diffengine.make_left_matmul(
55-
children[1],
56-
A.data.astype(np.float64),
57-
A.indices.astype(np.int32),
58-
A.indptr.astype(np.int32),
59-
A.shape[0],
60-
A.shape[1],
61-
)
51+
if sparse.issparse(A) or True:
52+
if not isinstance(A, sparse.csr_matrix):
53+
A = sparse.csr_matrix(A)
54+
55+
return _diffengine.make_sparse_left_matmul(
56+
children[1],
57+
A.data.astype(np.float64, copy=False),
58+
A.indices.astype(np.int32, copy=False),
59+
A.indptr.astype(np.int32, copy=False),
60+
A.shape[0],
61+
A.shape[1],
62+
)
63+
else:
64+
return _diffengine.make_dense_left_matmul(
65+
children[1],
66+
A.flatten(order='F'),
67+
A.shape[0],
68+
A.shape[1],
69+
)
6270
elif right_arg.is_constant():
6371
A = right_arg.value
64-
65-
if not isinstance(A, sparse.csr_matrix):
66-
A = sparse.csr_matrix(A)
67-
68-
return _diffengine.make_right_matmul(
69-
children[0],
70-
A.data.astype(np.float64),
71-
A.indices.astype(np.int32),
72-
A.indptr.astype(np.int32),
73-
A.shape[0],
74-
A.shape[1],
75-
)
72+
73+
if sparse.issparse(A) or True:
74+
if not isinstance(A, sparse.csr_matrix):
75+
A = sparse.csr_matrix(A)
76+
77+
return _diffengine.make_sparse_right_matmul(
78+
children[0],
79+
A.data.astype(np.float64, copy=False),
80+
A.indices.astype(np.int32, copy=False),
81+
A.indptr.astype(np.int32, copy=False),
82+
A.shape[0],
83+
A.shape[1],
84+
)
85+
else:
86+
return _diffengine.make_dense_right_matmul(
87+
children[0],
88+
A.flatten(order='F'),
89+
A.shape[0],
90+
A.shape[1],
91+
)
7692
else:
7793
return _diffengine.make_matmul(children[0], children[1])
7894

cvxpy/tests/nlp_tests/stress_tests_diff_engine/test_matmul_sparse.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,43 @@ def test_dense_sparse_sparse(self):
7373
assert np.allclose(dense_val, csc_val)
7474
assert np.allclose(dense_sol, sparse_sol)
7575
assert np.allclose(dense_sol, csc_sol)
76+
77+
def test_dense_left_matmul(self):
78+
np.random.seed(0)
79+
m, n = 4, 4
80+
A = np.random.rand(m, n)
81+
X = cp.Variable((n, n), nonneg=True)
82+
B = np.random.rand(m, n)
83+
obj = cp.Minimize(cp.sum_squares(A @ X - B))
84+
constraints = []
85+
problem = cp.Problem(obj, constraints)
86+
problem.solve(nlp=True, verbose=True)
87+
checker = DerivativeChecker(problem)
88+
checker.run_and_assert()
89+
90+
def test_dense_right_matmul(self):
91+
np.random.seed(0)
92+
m, n = 4, 4
93+
A = np.random.rand(m, n)
94+
X = cp.Variable((n, n), nonneg=True)
95+
B = np.random.rand(m, n)
96+
obj = cp.Minimize(cp.sum_squares(X @ A - B))
97+
constraints = []
98+
problem = cp.Problem(obj, constraints)
99+
problem.solve(nlp=True, verbose=True)
100+
checker = DerivativeChecker(problem)
101+
checker.run_and_assert()
102+
103+
def test_sparse_and_dense_matmul(self):
104+
np.random.seed(0)
105+
m, n = 4, 4
106+
A = np.random.rand(m, n)
107+
C = sp.random(m, n, density=0.5)
108+
X = cp.Variable((n, n), nonneg=True)
109+
B = np.random.rand(m, n)
110+
obj = cp.Minimize(cp.sum_squares(A @ X @ C - B))
111+
constraints = []
112+
problem = cp.Problem(obj, constraints)
113+
problem.solve(nlp=True, verbose=True)
114+
checker = DerivativeChecker(problem)
115+
checker.run_and_assert()

0 commit comments

Comments
 (0)