Skip to content

Commit c65e543

Browse files
dance858Transurgeon
authored andcommitted
2d A in converter
1 parent 98fc202 commit c65e543

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

cvxpy/reductions/solvers/nlp_solvers/diff_engine/converters.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def normalize_shape(shape):
3232
shape = tuple(shape)
3333
return (1,) * (2 - len(shape)) + shape
3434

35-
3635
def _chain_add(children):
3736
"""Chain multiple children with binary adds: a + b + c -> add(add(a, b), c)."""
3837
result = children[0]
@@ -61,11 +60,12 @@ def _convert_matmul(expr, children):
6160
A.shape[1],
6261
)
6362
else:
63+
m, n = normalize_shape(A.shape)
6464
return _diffengine.make_dense_left_matmul(
6565
children[1],
6666
A.flatten(order='C'),
67-
A.shape[0],
68-
A.shape[1],
67+
m,
68+
n,
6969
)
7070
elif right_arg.is_constant():
7171
A = right_arg.value
@@ -83,11 +83,12 @@ def _convert_matmul(expr, children):
8383
A.shape[1],
8484
)
8585
else:
86+
m, n = normalize_shape(A.shape)
8687
return _diffengine.make_dense_right_matmul(
8788
children[0],
8889
A.flatten(order='C'),
89-
A.shape[0],
90-
A.shape[1],
90+
m,
91+
n,
9192
)
9293
else:
9394
return _diffengine.make_matmul(children[0], children[1])

0 commit comments

Comments
 (0)