|
21 | 21 | import cvxpy.interface as intf |
22 | 22 | from cvxpy.atoms.affine.hstack import hstack |
23 | 23 | from cvxpy.atoms.axis_atom import AxisAtom |
| 24 | +from cvxpy.expressions.variable import Variable |
24 | 25 |
|
25 | 26 |
|
26 | 27 | class Prod(AxisAtom): |
@@ -70,6 +71,16 @@ def is_atom_log_log_concave(self) -> bool: |
70 | 71 | """ |
71 | 72 | return True |
72 | 73 |
|
| 74 | + def is_atom_esr(self) -> bool: |
| 75 | + """Is the atom ESR (epigraph smooth representable)? |
| 76 | + """ |
| 77 | + return True |
| 78 | + |
| 79 | + def is_atom_hsr(self) -> bool: |
| 80 | + """Is the atom HSR (hypograph smooth representable)? |
| 81 | + """ |
| 82 | + return True |
| 83 | + |
73 | 84 | def is_incr(self, idx) -> bool: |
74 | 85 | """Is the composition non-decreasing in argument idx? |
75 | 86 | """ |
@@ -133,6 +144,216 @@ def _grad(self, values): |
133 | 144 | """ |
134 | 145 | return self._axis_grad(values) |
135 | 146 |
|
| 147 | + def _verify_jacobian_args(self): |
| 148 | + return isinstance(self.args[0], Variable) |
| 149 | + |
| 150 | + def _input_to_output_indices(self, in_shape): |
| 151 | + """ |
| 152 | + Map each flattened input index to its corresponding output index. |
| 153 | +
|
| 154 | + For axis reduction, each input element contributes to exactly one output. |
| 155 | + Returns array of length prod(in_shape) with output index for each input. |
| 156 | + """ |
| 157 | + n_in = int(np.prod(in_shape)) |
| 158 | + in_indices = np.arange(n_in) |
| 159 | + in_multi = np.array(np.unravel_index(in_indices, in_shape, order='F')).T |
| 160 | + |
| 161 | + if self.keepdims: |
| 162 | + out_multi = in_multi.copy() |
| 163 | + out_multi[:, self.axis] = 0 |
| 164 | + out_shape = np.array(in_shape) |
| 165 | + out_shape[self.axis] = 1 |
| 166 | + else: |
| 167 | + out_multi = np.delete(in_multi, self.axis, axis=1) |
| 168 | + out_shape = np.delete(np.array(in_shape), self.axis) |
| 169 | + |
| 170 | + if len(out_shape) == 0: |
| 171 | + return np.zeros(n_in, dtype=int) |
| 172 | + return np.ravel_multi_index(out_multi.T, out_shape, order='F') |
| 173 | + |
| 174 | + @staticmethod |
| 175 | + def _prod_except_self(arr): |
| 176 | + """ |
| 177 | + Compute the product of all elements except each element itself. |
| 178 | +
|
| 179 | + For arr = [a, b, c, d], returns [b*c*d, a*c*d, a*b*d, a*b*c]. |
| 180 | +
|
| 181 | + Uses prefix and suffix products to avoid division and handle zeros. |
| 182 | + Fully vectorized using cumulative products. |
| 183 | + """ |
| 184 | + n = len(arr) |
| 185 | + if n == 0: |
| 186 | + return np.array([]) |
| 187 | + if n == 1: |
| 188 | + return np.array([1.0]) |
| 189 | + |
| 190 | + # prefix[i] = arr[0] * arr[1] * ... * arr[i-1] |
| 191 | + prefix = np.empty(n) |
| 192 | + prefix[0] = 1.0 |
| 193 | + prefix[1:] = np.cumprod(arr[:-1]) |
| 194 | + |
| 195 | + # suffix[i] = arr[i+1] * arr[i+2] * ... * arr[n-1] |
| 196 | + suffix = np.empty(n) |
| 197 | + suffix[-1] = 1.0 |
| 198 | + suffix[:-1] = np.cumprod(arr[::-1])[:-1][::-1] |
| 199 | + |
| 200 | + return prefix * suffix |
| 201 | + |
| 202 | + @staticmethod |
| 203 | + def _prod_except_pairs(arr): |
| 204 | + """ |
| 205 | + Compute the product of all elements except each pair (i, j) for i != j. |
| 206 | +
|
| 207 | + Returns an n x n matrix H where H[i,j] = prod of arr except indices i and j. |
| 208 | + Diagonal entries (i == i) are set to 0. |
| 209 | +
|
| 210 | + Returns None if all entries would be zero (3+ zeros in arr). |
| 211 | +
|
| 212 | + Handles zeros correctly without division. |
| 213 | + """ |
| 214 | + n = len(arr) |
| 215 | + if n == 0: |
| 216 | + return None |
| 217 | + if n == 1: |
| 218 | + return None |
| 219 | + |
| 220 | + zero_mask = (arr == 0) |
| 221 | + num_zeros = np.sum(zero_mask) |
| 222 | + |
| 223 | + if num_zeros >= 3: |
| 224 | + # Three or more zeros: all products are zero |
| 225 | + return None |
| 226 | + |
| 227 | + if num_zeros == 0: |
| 228 | + # No zeros: use prod(arr) / (arr[i] * arr[j]) |
| 229 | + total_prod = np.prod(arr) |
| 230 | + H = total_prod / np.outer(arr, arr) |
| 231 | + np.fill_diagonal(H, 0.0) |
| 232 | + elif num_zeros == 1: |
| 233 | + # One zero at index k: only H[k, j] and H[j, k] are nonzero for j != k |
| 234 | + k = np.where(zero_mask)[0][0] |
| 235 | + prod_nonzero = np.prod(arr[~zero_mask]) |
| 236 | + H = np.zeros((n, n)) |
| 237 | + # H[k, j] = prod_nonzero / arr[j] for j != k |
| 238 | + nonzero_mask = ~zero_mask |
| 239 | + H[k, nonzero_mask] = prod_nonzero / arr[nonzero_mask] |
| 240 | + H[nonzero_mask, k] = H[k, nonzero_mask] |
| 241 | + else: # num_zeros == 2 |
| 242 | + # Two zeros: only H[k1, k2] and H[k2, k1] are nonzero |
| 243 | + zero_indices = np.where(zero_mask)[0] |
| 244 | + k1, k2 = zero_indices[0], zero_indices[1] |
| 245 | + prod_nonzero = np.prod(arr[~zero_mask]) |
| 246 | + H = np.zeros((n, n)) |
| 247 | + H[k1, k2] = prod_nonzero |
| 248 | + H[k2, k1] = prod_nonzero |
| 249 | + |
| 250 | + return H |
| 251 | + |
| 252 | + def _jacobian(self): |
| 253 | + """ |
| 254 | + The jacobian of prod(x) with respect to x. |
| 255 | +
|
| 256 | + For prod(x) = x_1 * x_2 * ... * x_n: |
| 257 | + ∂prod(x)/∂x_i = prod_{j != i}(x_j) |
| 258 | +
|
| 259 | + Uses prefix/suffix products to handle zeros correctly. |
| 260 | + """ |
| 261 | + x = self.args[0] |
| 262 | + x_val = x.value |
| 263 | + n_in = x.size |
| 264 | + col_idxs = np.arange(n_in, dtype=int) |
| 265 | + |
| 266 | + if self.axis is None: |
| 267 | + grad_vals = self._prod_except_self(x_val.flatten(order='F')) |
| 268 | + row_idxs = np.zeros(n_in, dtype=int) |
| 269 | + else: |
| 270 | + grad_vals = np.apply_along_axis( |
| 271 | + self._prod_except_self, self.axis, x_val |
| 272 | + ).flatten(order='F') |
| 273 | + row_idxs = self._input_to_output_indices(x.shape) |
| 274 | + |
| 275 | + return {x: (row_idxs, col_idxs, grad_vals)} |
| 276 | + |
| 277 | + def _verify_hess_vec_args(self): |
| 278 | + return isinstance(self.args[0], Variable) |
| 279 | + |
| 280 | + def _hess_vec(self, vec): |
| 281 | + """ |
| 282 | + Compute weighted sum of Hessians for prod(x). |
| 283 | +
|
| 284 | + vec has size equal to the output dimension of prod. |
| 285 | + For axis=None, output is scalar, so vec has size 1. |
| 286 | + For axis != None, vec has size equal to prod of non-reduced dimensions. |
| 287 | +
|
| 288 | + The Hessian of prod for each output component has: |
| 289 | + H[i,j] = prod_{k != i,j}(x_k) for i != j (among inputs to that component) |
| 290 | + H[i,i] = 0 |
| 291 | +
|
| 292 | + Returns weighted combination: sum_k vec[k] * H_k |
| 293 | + """ |
| 294 | + x = self.args[0] |
| 295 | + x_val = x.value |
| 296 | + n_in = x.size |
| 297 | + empty = (np.array([], dtype=int), np.array([], dtype=int), np.array([])) |
| 298 | + |
| 299 | + if self.axis is None: |
| 300 | + H = self._prod_except_pairs(x_val.flatten(order='F')) |
| 301 | + if H is None: |
| 302 | + return {(x, x): empty} |
| 303 | + H = vec[0] * H |
| 304 | + row_idxs, col_idxs = np.meshgrid( |
| 305 | + np.arange(n_in), np.arange(n_in), indexing='ij' |
| 306 | + ) |
| 307 | + return {(x, x): (row_idxs.ravel(), col_idxs.ravel(), H.ravel())} |
| 308 | + |
| 309 | + # Multiple outputs: vec[k] weights the k-th output's Hessian |
| 310 | + in_indices = np.arange(n_in) |
| 311 | + out_idxs = self._input_to_output_indices(x.shape) |
| 312 | + axis_positions = np.unravel_index(in_indices, x.shape, order='F')[self.axis] |
| 313 | + n_out = len(np.unique(out_idxs)) |
| 314 | + x_flat = x_val.flatten(order='F') |
| 315 | + |
| 316 | + all_rows = [] |
| 317 | + all_cols = [] |
| 318 | + all_vals = [] |
| 319 | + |
| 320 | + for out_idx in range(n_out): |
| 321 | + mask = (out_idxs == out_idx) |
| 322 | + local_in_indices = in_indices[mask] |
| 323 | + |
| 324 | + # Sort by axis position to align with _prod_except_pairs |
| 325 | + sort_order = np.argsort(axis_positions[mask]) |
| 326 | + sorted_in_indices = local_in_indices[sort_order] |
| 327 | + |
| 328 | + H_local = self._prod_except_pairs(x_flat[sorted_in_indices]) |
| 329 | + if H_local is None: |
| 330 | + continue |
| 331 | + |
| 332 | + H_local = vec[out_idx] * H_local |
| 333 | + |
| 334 | + # Build global index pairs using meshgrid |
| 335 | + m = len(sorted_in_indices) |
| 336 | + local_rows, local_cols = np.meshgrid( |
| 337 | + np.arange(m), np.arange(m), indexing='ij' |
| 338 | + ) |
| 339 | + global_rows = sorted_in_indices[local_rows.ravel()] |
| 340 | + global_cols = sorted_in_indices[local_cols.ravel()] |
| 341 | + vals = H_local.ravel() |
| 342 | + |
| 343 | + # Filter out zeros |
| 344 | + nonzero = vals != 0 |
| 345 | + all_rows.append(global_rows[nonzero]) |
| 346 | + all_cols.append(global_cols[nonzero]) |
| 347 | + all_vals.append(vals[nonzero]) |
| 348 | + |
| 349 | + if all_rows: |
| 350 | + return {(x, x): ( |
| 351 | + np.concatenate(all_rows), |
| 352 | + np.concatenate(all_cols), |
| 353 | + np.concatenate(all_vals) |
| 354 | + )} |
| 355 | + return {(x, x): empty} |
| 356 | + |
136 | 357 |
|
137 | 358 | def prod(expr, axis=None, keepdims: bool = False) -> Prod: |
138 | 359 | """Multiply the entries of an expression. |
|
0 commit comments