Skip to content

Commit e129758

Browse files
committed
equivalent treatment of (n, ) as numpy and cvxpy. Very subtle
1 parent 78c9e44 commit e129758

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

cvxpy/reductions/solvers/nlp_solvers/diff_engine/converters.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@
3636

3737
def _chain_add(children):
3838
"""Chain multiple children with binary adds: a + b + c -> add(add(a, b), c)."""
39-
print("start make add")
4039
result = children[0]
4140
for child in children[1:]:
4241
result = _diffengine.make_add(result, child)
43-
print("end make add")
4442
return result
4543

4644

@@ -201,6 +199,7 @@ def _convert_reshape(expr, children):
201199
)
202200

203201
d1, d2 = expr.shape
202+
# TODO: can it happen that len(expr.shape) < 2?
204203
return _diffengine.make_reshape(children[0], d1, d2)
205204

206205
def _convert_broadcast(expr, children):
@@ -213,6 +212,13 @@ def _convert_sum(expr, children):
213212
axis = -1
214213
return _diffengine.make_sum(children[0], axis)
215214

215+
def _convert_promote(expr, children):
216+
x_shape = tuple(expr.shape)
217+
x_shape = (1,) * (2 - len(x_shape)) + x_shape
218+
d1, d2 = x_shape
219+
220+
return _diffengine.make_promote(children[0], d1, d2)
221+
216222
# Mapping from CVXPY atom names to C diff engine functions
217223
# Converters receive (expr, children) where expr is the CVXPY expression
218224
ATOM_CONVERTERS = {
@@ -221,11 +227,7 @@ def _convert_sum(expr, children):
221227
"exp": lambda _expr, children: _diffengine.make_exp(children[0]),
222228
# Affine unary
223229
"NegExpression": lambda _expr, children: _diffengine.make_neg(children[0]),
224-
"Promote": lambda expr, children: _diffengine.make_promote(
225-
children[0],
226-
expr.shape[0] if len(expr.shape) >= 1 else 1,
227-
expr.shape[1] if len(expr.shape) >= 2 else 1,
228-
),
230+
"Promote": _convert_promote,
229231
# N-ary (handles 2+ args)
230232
"AddExpression": lambda _expr, children: _chain_add(children),
231233
# Reductions
@@ -288,7 +290,9 @@ def build_variable_dict(variables: list) -> tuple[dict, int]:
288290
if len(shape) == 2:
289291
d1, d2 = shape[0], shape[1]
290292
elif len(shape) == 1:
291-
d1, d2 = shape[0], 1
293+
# NuMPy and CVXPY broadcasting rules treat a (n, ) vector as (1, n),
294+
# not as (n, 1)
295+
d1, d2 = 1, shape[0]
292296
else: # scalar
293297
d1, d2 = 1, 1
294298
c_var = _diffengine.make_variable(d1, d2, offset, n_vars)
@@ -306,8 +310,10 @@ def convert_expr(expr, var_dict: dict, n_vars: int):
306310
# Base case: constant
307311
if isinstance(expr, cp.Constant):
308312
value = np.asarray(expr.value, dtype=np.float64).flatten(order='F')
309-
d1 = expr.shape[0] if len(expr.shape) >= 1 else 1
310-
d2 = expr.shape[1] if len(expr.shape) >= 2 else 1
313+
x_shape = tuple(expr.shape)
314+
x_shape = (1,) * (2 - len(x_shape)) + x_shape
315+
d1, d2 = x_shape
316+
311317
return _diffengine.make_constant(d1, d2, n_vars, value)
312318

313319
# Recursive case: atoms
@@ -319,8 +325,10 @@ def convert_expr(expr, var_dict: dict, n_vars: int):
319325

320326
# check that python dimension is consistent with C dimension
321327
d1_C, d2_C = _diffengine.get_expr_dimensions(C_expr)
322-
d1_Python = expr.shape[0] if len(expr.shape) >= 1 else 1
323-
d2_Python = expr.shape[1] if len(expr.shape) >= 2 else 1
328+
x_shape = tuple(expr.shape)
329+
x_shape = (1,) * (2 - len(x_shape)) + x_shape
330+
d1_Python, d2_Python = x_shape
331+
324332
if d1_C != d1_Python or d2_C != d2_Python:
325333
raise ValueError(
326334
f"Dimension mismatch for atom '{atom_name}': "

0 commit comments

Comments
 (0)