Skip to content

Commit b439d5e

Browse files
committed
2d A in converter
1 parent b387fbb commit b439d5e

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

cvxpy/reductions/solvers/nlp_solvers/diff_engine/converters.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def normalize_shape(shape):
3232
shape = tuple(shape)
3333
return (1,) * (2 - len(shape)) + shape
3434

35-
3635
def _chain_add(children):
3736
"""Chain multiple children with binary adds: a + b + c -> add(add(a, b), c)."""
3837
result = children[0]
@@ -61,11 +60,12 @@ def _convert_matmul(expr, children):
6160
A.shape[1],
6261
)
6362
else:
63+
m, n = normalize_shape(A.shape)
6464
return _diffengine.make_dense_left_matmul(
6565
children[1],
6666
A.flatten(order='C'),
67-
A.shape[0],
68-
A.shape[1],
67+
m,
68+
n,
6969
)
7070

7171
elif right_arg.is_constant():
@@ -84,11 +84,12 @@ def _convert_matmul(expr, children):
8484
A.shape[1],
8585
)
8686
else:
87+
m, n = normalize_shape(A.shape)
8788
return _diffengine.make_dense_right_matmul(
8889
children[0],
8990
A.flatten(order='C'),
90-
A.shape[0],
91-
A.shape[1],
91+
m,
92+
n,
9293
)
9394
else:
9495
return _diffengine.make_matmul(children[0], children[1])

0 commit comments

Comments
 (0)