Skip to content

Commit b16501b

Browse files
authored
Adds prod and hstack/vstack and fixes an issue with the newest version of NumPy (#120)
* Adds DNLP cp.prod * Adds hstack/vstack * Fixes failing tests * Fixes hallucination * Addresses Daniel's comments * Removes unneeded code * Addresses Daniel's comments * Fixes constraints in broken test problem
1 parent eb728ff commit b16501b

File tree

10 files changed

+1306
-1
lines changed

10 files changed

+1306
-1
lines changed

cvxpy/atoms/affine/hstack.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Tuple
1717

1818
import numpy as np
19+
from scipy.sparse import coo_matrix
1920

2021
import cvxpy.lin_ops.lin_op as lo
2122
import cvxpy.lin_ops.lin_utils as lu
@@ -83,3 +84,65 @@ def graph_implementation(
8384
(LinOp for objective, list of constraints)
8485
"""
8586
return (lu.hstack(arg_objs, shape), [])
87+
88+
def _verify_jacobian_args(self):
89+
return True
90+
91+
def _jacobian(self):
92+
result = {}
93+
94+
flat_offset = 0
95+
for arg in self.args:
96+
jac = arg.jacobian()
97+
98+
for k, (rows, cols, vals) in jac.items():
99+
new_rows = rows + flat_offset
100+
if k in result:
101+
old_rows, old_cols, old_vals = result[k]
102+
result[k] = (
103+
np.concatenate([old_rows, new_rows]),
104+
np.concatenate([old_cols, cols]),
105+
np.concatenate([old_vals, vals]),
106+
)
107+
else:
108+
result[k] = (new_rows, cols, vals)
109+
110+
flat_offset += arg.size
111+
112+
return result
113+
114+
def _verify_hess_vec_args(self):
115+
return True
116+
117+
def _hess_vec(self, vec):
118+
result = {}
119+
keys_require_summing = []
120+
121+
flat_offset = 0
122+
for arg in self.args:
123+
arg_vec = vec[flat_offset:flat_offset + arg.size]
124+
125+
arg_result = arg.hess_vec(arg_vec)
126+
for k, v in arg_result.items():
127+
if k in result:
128+
old_rows, old_cols, old_vals = result[k]
129+
new_rows, new_cols, new_vals = v
130+
result[k] = (
131+
np.concatenate([old_rows, new_rows]),
132+
np.concatenate([old_cols, new_cols]),
133+
np.concatenate([old_vals, new_vals]),
134+
)
135+
keys_require_summing.append(k)
136+
else:
137+
result[k] = v
138+
139+
flat_offset += arg.size
140+
141+
for k in set(keys_require_summing):
142+
rows, cols, vals = result[k]
143+
var1, var2 = k
144+
hess = coo_matrix((vals, (rows, cols)), shape=(var1.size, var2.size))
145+
hess.sum_duplicates()
146+
result[k] = (hess.row, hess.col, hess.data)
147+
148+
return result

cvxpy/atoms/affine/vstack.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Tuple
1717

1818
import numpy as np
19+
from scipy.sparse import coo_matrix
1920

2021
import cvxpy.lin_ops.lin_op as lo
2122
import cvxpy.lin_ops.lin_utils as lu
@@ -79,3 +80,70 @@ def graph_implementation(
7980
(LinOp for objective, list of constraints)
8081
"""
8182
return (lu.vstack(arg_objs, shape), [])
83+
84+
def _verify_jacobian_args(self):
85+
return True
86+
87+
def _jacobian(self):
88+
result = {}
89+
M = self.shape[0]
90+
91+
row_offset = 0
92+
for arg in self.args:
93+
jac = arg.jacobian()
94+
m_j = arg.shape[0] if arg.ndim >= 2 else 1
95+
for k, (rows, cols, vals) in jac.items():
96+
new_rows = (rows % m_j) + row_offset + (rows // m_j) * M
97+
if k in result:
98+
old_rows, old_cols, old_vals = result[k]
99+
result[k] = (
100+
np.concatenate([old_rows, new_rows]),
101+
np.concatenate([old_cols, cols]),
102+
np.concatenate([old_vals, vals]),
103+
)
104+
else:
105+
result[k] = (new_rows, cols, vals)
106+
row_offset += m_j
107+
108+
return result
109+
110+
def _verify_hess_vec_args(self):
111+
return True
112+
113+
def _hess_vec(self, vec):
114+
M = self.shape[0]
115+
result = {}
116+
keys_require_summing = []
117+
118+
row_offset = 0
119+
for arg in self.args:
120+
m_j = arg.shape[0] if arg.ndim >= 2 else 1
121+
122+
arg_indices = np.arange(arg.size)
123+
output_indices = (arg_indices % m_j) + row_offset + (arg_indices // m_j) * M
124+
arg_vec = vec[output_indices]
125+
126+
arg_result = arg.hess_vec(arg_vec)
127+
for k, v in arg_result.items():
128+
if k in result:
129+
old_rows, old_cols, old_vals = result[k]
130+
new_rows, new_cols, new_vals = v
131+
result[k] = (
132+
np.concatenate([old_rows, new_rows]),
133+
np.concatenate([old_cols, new_cols]),
134+
np.concatenate([old_vals, new_vals]),
135+
)
136+
keys_require_summing.append(k)
137+
else:
138+
result[k] = v
139+
140+
row_offset += m_j
141+
142+
for k in set(keys_require_summing):
143+
rows, cols, vals = result[k]
144+
var1, var2 = k
145+
hess = coo_matrix((vals, (rows, cols)), shape=(var1.size, var2.size))
146+
hess.sum_duplicates()
147+
result[k] = (hess.row, hess.col, hess.data)
148+
149+
return result

cvxpy/atoms/prod.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import cvxpy.interface as intf
2222
from cvxpy.atoms.affine.hstack import hstack
2323
from cvxpy.atoms.axis_atom import AxisAtom
24+
from cvxpy.expressions.variable import Variable
2425

2526

2627
class Prod(AxisAtom):
@@ -70,6 +71,16 @@ def is_atom_log_log_concave(self) -> bool:
7071
"""
7172
return True
7273

74+
def is_atom_esr(self) -> bool:
    """Return whether this atom is ESR (epigraph smooth representable).

    Always True for this atom.
    """
    return True
78+
79+
def is_atom_hsr(self) -> bool:
    """Return whether this atom is HSR (hypograph smooth representable).

    Always True for this atom.
    """
    return True
83+
7384
def is_incr(self, idx) -> bool:
7485
"""Is the composition non-decreasing in argument idx?
7586
"""
@@ -133,6 +144,216 @@ def _grad(self, values):
133144
"""
134145
return self._axis_grad(values)
135146

147+
def _verify_jacobian_args(self):
    """Report whether the Jacobian can be formed for the current argument.

    The check passes only when the first (and only inspected) argument is
    a ``Variable`` instance.
    """
    first_arg = self.args[0]
    return isinstance(first_arg, Variable)
149+
150+
def _input_to_output_indices(self, in_shape):
151+
"""
152+
Map each flattened input index to its corresponding output index.
153+
154+
For axis reduction, each input element contributes to exactly one output.
155+
Returns array of length prod(in_shape) with output index for each input.
156+
"""
157+
n_in = int(np.prod(in_shape))
158+
in_indices = np.arange(n_in)
159+
in_multi = np.array(np.unravel_index(in_indices, in_shape, order='F')).T
160+
161+
if self.keepdims:
162+
out_multi = in_multi.copy()
163+
out_multi[:, self.axis] = 0
164+
out_shape = np.array(in_shape)
165+
out_shape[self.axis] = 1
166+
else:
167+
out_multi = np.delete(in_multi, self.axis, axis=1)
168+
out_shape = np.delete(np.array(in_shape), self.axis)
169+
170+
if len(out_shape) == 0:
171+
return np.zeros(n_in, dtype=int)
172+
return np.ravel_multi_index(out_multi.T, out_shape, order='F')
173+
174+
@staticmethod
175+
def _prod_except_self(arr):
176+
"""
177+
Compute the product of all elements except each element itself.
178+
179+
For arr = [a, b, c, d], returns [b*c*d, a*c*d, a*b*d, a*b*c].
180+
181+
Uses prefix and suffix products to avoid division and handle zeros.
182+
Fully vectorized using cumulative products.
183+
"""
184+
n = len(arr)
185+
if n == 0:
186+
return np.array([])
187+
if n == 1:
188+
return np.array([1.0])
189+
190+
# prefix[i] = arr[0] * arr[1] * ... * arr[i-1]
191+
prefix = np.empty(n)
192+
prefix[0] = 1.0
193+
prefix[1:] = np.cumprod(arr[:-1])
194+
195+
# suffix[i] = arr[i+1] * arr[i+2] * ... * arr[n-1]
196+
suffix = np.empty(n)
197+
suffix[-1] = 1.0
198+
suffix[:-1] = np.cumprod(arr[::-1])[:-1][::-1]
199+
200+
return prefix * suffix
201+
202+
@staticmethod
203+
def _prod_except_pairs(arr):
204+
"""
205+
Compute the product of all elements except each pair (i, j) for i != j.
206+
207+
Returns an n x n matrix H where H[i,j] = prod of arr except indices i and j.
208+
Diagonal entries (i == i) are set to 0.
209+
210+
Returns None if all entries would be zero (3+ zeros in arr).
211+
212+
Handles zeros correctly without division.
213+
"""
214+
n = len(arr)
215+
if n == 0:
216+
return None
217+
if n == 1:
218+
return None
219+
220+
zero_mask = (arr == 0)
221+
num_zeros = np.sum(zero_mask)
222+
223+
if num_zeros >= 3:
224+
# Three or more zeros: all products are zero
225+
return None
226+
227+
if num_zeros == 0:
228+
# No zeros: use prod(arr) / (arr[i] * arr[j])
229+
total_prod = np.prod(arr)
230+
H = total_prod / np.outer(arr, arr)
231+
np.fill_diagonal(H, 0.0)
232+
elif num_zeros == 1:
233+
# One zero at index k: only H[k, j] and H[j, k] are nonzero for j != k
234+
k = np.where(zero_mask)[0][0]
235+
prod_nonzero = np.prod(arr[~zero_mask])
236+
H = np.zeros((n, n))
237+
# H[k, j] = prod_nonzero / arr[j] for j != k
238+
nonzero_mask = ~zero_mask
239+
H[k, nonzero_mask] = prod_nonzero / arr[nonzero_mask]
240+
H[nonzero_mask, k] = H[k, nonzero_mask]
241+
else: # num_zeros == 2
242+
# Two zeros: only H[k1, k2] and H[k2, k1] are nonzero
243+
zero_indices = np.where(zero_mask)[0]
244+
k1, k2 = zero_indices[0], zero_indices[1]
245+
prod_nonzero = np.prod(arr[~zero_mask])
246+
H = np.zeros((n, n))
247+
H[k1, k2] = prod_nonzero
248+
H[k2, k1] = prod_nonzero
249+
250+
return H
251+
252+
def _jacobian(self):
253+
"""
254+
The jacobian of prod(x) with respect to x.
255+
256+
For prod(x) = x_1 * x_2 * ... * x_n:
257+
∂prod(x)/∂x_i = prod_{j != i}(x_j)
258+
259+
Uses prefix/suffix products to handle zeros correctly.
260+
"""
261+
x = self.args[0]
262+
x_val = x.value
263+
n_in = x.size
264+
col_idxs = np.arange(n_in, dtype=int)
265+
266+
if self.axis is None:
267+
grad_vals = self._prod_except_self(x_val.flatten(order='F'))
268+
row_idxs = np.zeros(n_in, dtype=int)
269+
else:
270+
grad_vals = np.apply_along_axis(
271+
self._prod_except_self, self.axis, x_val
272+
).flatten(order='F')
273+
row_idxs = self._input_to_output_indices(x.shape)
274+
275+
return {x: (row_idxs, col_idxs, grad_vals)}
276+
277+
def _verify_hess_vec_args(self):
    """Report whether ``_hess_vec`` can be evaluated for the current argument.

    The check passes only when the first (and only inspected) argument is
    a ``Variable`` instance.
    """
    first_arg = self.args[0]
    return isinstance(first_arg, Variable)
279+
280+
def _hess_vec(self, vec):
281+
"""
282+
Compute weighted sum of Hessians for prod(x).
283+
284+
vec has size equal to the output dimension of prod.
285+
For axis=None, output is scalar, so vec has size 1.
286+
For axis != None, vec has size equal to prod of non-reduced dimensions.
287+
288+
The Hessian of prod for each output component has:
289+
H[i,j] = prod_{k != i,j}(x_k) for i != j (among inputs to that component)
290+
H[i,i] = 0
291+
292+
Returns weighted combination: sum_k vec[k] * H_k
293+
"""
294+
x = self.args[0]
295+
x_val = x.value
296+
n_in = x.size
297+
empty = (np.array([], dtype=int), np.array([], dtype=int), np.array([]))
298+
299+
if self.axis is None:
300+
H = self._prod_except_pairs(x_val.flatten(order='F'))
301+
if H is None:
302+
return {(x, x): empty}
303+
H = vec[0] * H
304+
row_idxs, col_idxs = np.meshgrid(
305+
np.arange(n_in), np.arange(n_in), indexing='ij'
306+
)
307+
return {(x, x): (row_idxs.ravel(), col_idxs.ravel(), H.ravel())}
308+
309+
# Multiple outputs: vec[k] weights the k-th output's Hessian
310+
in_indices = np.arange(n_in)
311+
out_idxs = self._input_to_output_indices(x.shape)
312+
axis_positions = np.unravel_index(in_indices, x.shape, order='F')[self.axis]
313+
n_out = len(np.unique(out_idxs))
314+
x_flat = x_val.flatten(order='F')
315+
316+
all_rows = []
317+
all_cols = []
318+
all_vals = []
319+
320+
for out_idx in range(n_out):
321+
mask = (out_idxs == out_idx)
322+
local_in_indices = in_indices[mask]
323+
324+
# Sort by axis position to align with _prod_except_pairs
325+
sort_order = np.argsort(axis_positions[mask])
326+
sorted_in_indices = local_in_indices[sort_order]
327+
328+
H_local = self._prod_except_pairs(x_flat[sorted_in_indices])
329+
if H_local is None:
330+
continue
331+
332+
H_local = vec[out_idx] * H_local
333+
334+
# Build global index pairs using meshgrid
335+
m = len(sorted_in_indices)
336+
local_rows, local_cols = np.meshgrid(
337+
np.arange(m), np.arange(m), indexing='ij'
338+
)
339+
global_rows = sorted_in_indices[local_rows.ravel()]
340+
global_cols = sorted_in_indices[local_cols.ravel()]
341+
vals = H_local.ravel()
342+
343+
# Filter out zeros
344+
nonzero = vals != 0
345+
all_rows.append(global_rows[nonzero])
346+
all_cols.append(global_cols[nonzero])
347+
all_vals.append(vals[nonzero])
348+
349+
if all_rows:
350+
return {(x, x): (
351+
np.concatenate(all_rows),
352+
np.concatenate(all_cols),
353+
np.concatenate(all_vals)
354+
)}
355+
return {(x, x): empty}
356+
136357

137358
def prod(expr, axis=None, keepdims: bool = False) -> Prod:
138359
"""Multiply the entries of an expression.

0 commit comments

Comments
 (0)