File tree Expand file tree Collapse file tree 1 file changed +6
-5
lines changed
cvxpy/reductions/solvers/nlp_solvers/diff_engine Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -32,7 +32,6 @@ def normalize_shape(shape):
3232 shape = tuple (shape )
3333 return (1 ,) * (2 - len (shape )) + shape
3434
35-
3635def _chain_add (children ):
3736 """Chain multiple children with binary adds: a + b + c -> add(add(a, b), c)."""
3837 result = children [0 ]
@@ -61,11 +60,12 @@ def _convert_matmul(expr, children):
6160 A .shape [1 ],
6261 )
6362 else :
63+ m , n = normalize_shape (A .shape )
6464 return _diffengine .make_dense_left_matmul (
6565 children [1 ],
6666 A .flatten (order = 'C' ),
67- A . shape [ 0 ] ,
68- A . shape [ 1 ] ,
67+ m ,
68+ n ,
6969 )
7070
7171 elif right_arg .is_constant ():
@@ -84,11 +84,12 @@ def _convert_matmul(expr, children):
8484 A .shape [1 ],
8585 )
8686 else :
87+ m , n = normalize_shape (A .shape )
8788 return _diffengine .make_dense_right_matmul (
8889 children [0 ],
8990 A .flatten (order = 'C' ),
90- A . shape [ 0 ] ,
91- A . shape [ 1 ] ,
91+ m ,
92+ n ,
9293 )
9394 else :
9495 return _diffengine .make_matmul (children [0 ], children [1 ])
You can’t perform that action at this time.
0 commit comments