Skip to content

Commit f9a0108

Browse files
authored
fixing all not-implemented hessian issues (#45)
* fixing all not-implemented hessian issues * making some tests pass with hessian_approx=exact * delete the remove laters --------- Co-authored-by: William Zijie Zhang <william@gridmatic.com>
1 parent 044ae0c commit f9a0108

File tree

15 files changed

+166
-227
lines changed

15 files changed

+166
-227
lines changed

cvxpy/atoms/affine/binary_operators.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from cvxpy.atoms.affine.sum import sum as cvxpy_sum
3333
from cvxpy.constraints.constraint import Constraint
3434
from cvxpy.error import DCPError
35-
from cvxpy.expressions.constants.constant import Constant
3635
from cvxpy.expressions.constants.parameter import (
3736
is_param_affine,
3837
is_param_free,
@@ -235,33 +234,60 @@ def _grad(self, values):
235234

236235
return [DX, DY]
237236

238-
def _hess(self, values):
239-
"""Compute the Hessian of elementwise multiplication w.r.t. each argument.
237+
def _verify_hess_vec_args(self):
238+
x = self.args[0]
239+
y = self.args[1]
240+
if x.size != y.size:
241+
return False
240242

241-
For z = x * y (elementwise), returns:
242-
- d2z/dx2 = diag(y)
243-
- d2z/dy2 = diag(x)
243+
if x.is_constant() and y.is_constant():
244+
return False
244245

245-
Args:
246-
values: A list of numeric values for the arguments [x, y].
246+
# one of the following must be true:
247+
# 1. both arguments are variables
248+
# 2. one argument is a constant
249+
# 3. one argument is a Promote of a variable and the other is a variable
250+
both_are_variables = isinstance(x, Variable) and isinstance(y, Variable)
251+
one_is_constant = x.is_constant() or y.is_constant()
252+
x_is_promote = type(x) == Promote and isinstance(y, Variable)
253+
y_is_promote = type(y) == Promote and isinstance(x, Variable)
247254

248-
Returns:
249-
A list of SciPy CSC sparse matrices [D2X, D2Y].
250-
"""
251-
if isinstance(self.args[0], Variable) and isinstance(self.args[1], Variable):
252-
return {(self.args[0], self.args[1]): np.eye(self.size),
253-
(self.args[1], self.args[0]): np.eye(self.size)}
254-
if isinstance(self.args[0], Constant) and isinstance(self.args[1], Variable):
255-
return self.args[1].hess
256-
x = values[0]
257-
y = values[1]
258-
# what is the hessian of elementwise multiplication?
259-
# Flatten in case inputs are not 1D
260-
x = np.asarray(x).flatten(order='F')
261-
y = np.asarray(y).flatten(order='F')
262-
D2X = sp.diags(y, format='csc')
263-
D2Y = sp.diags(x, format='csc')
264-
return [D2X, D2Y]
255+
if not (both_are_variables or one_is_constant or x_is_promote or y_is_promote):
256+
return False
257+
258+
if both_are_variables and x.id == y.id:
259+
return False
260+
261+
return True
262+
263+
def _hess_vec(self, vec):
264+
x = self.args[0]
265+
y = self.args[1]
266+
267+
# constant * atom
268+
if x.is_constant():
269+
y_hess_vec = y.hess_vec(x.value * vec)
270+
return y_hess_vec
271+
272+
# atom * constant
273+
if y.is_constant():
274+
x_hess_vec = x.hess_vec(y.value * vec)
275+
return x_hess_vec
276+
277+
# x * y with x a scalar variable, y a vector variable
278+
if not isinstance(x, Variable) and x.is_affine():
279+
assert(type(x) == Promote)
280+
x_var = x.args[0] # here x is a Promote because of how we canonicalize
281+
return {(x_var, y): vec, (y, x_var): vec}
282+
283+
# x * y with x a vector variable, y a scalar
284+
if not isinstance(y, Variable) and y.is_affine():
285+
assert(type(y) == Promote)
286+
y_var = y.args[0] # here y is a Promote because of how we canonicalize
287+
return {(x, y_var): vec, (y_var, x): vec}
288+
289+
# if we arrive here both arguments are variables of the same size
290+
return {(x, y): np.diag(vec), (y, x): np.diag(vec)}
265291

266292
def graph_implementation(
267293
self, arg_objs, shape: Tuple[int, ...], data=None

cvxpy/atoms/elementwise/abs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,17 @@ def _grad(self, values):
8383
D += (values[0] > 0)
8484
D -= (values[0] < 0)
8585
return [abs.elemwise_grad_to_diag(D, rows, cols)]
86+
87+
def _verify_hess_vec_args(self):
88+
return True
89+
90+
def _hess_vec(self, vec):
91+
"""
92+
Computes the Hessian-vector product dictionary
93+
for the abs atom. We assume that the argument will be a variable.
94+
"""
95+
raise NotImplementedError("Second derivative of abs is not implemented yet.")
96+
hess_dict = {}
97+
var = self.args[0]
98+
hess_dict[(var, var)] = np.diag(np.sign(var.value) * vec)
99+
return hess_dict

cvxpy/atoms/elementwise/kl_div.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from cvxpy.atoms.elementwise.elementwise import Elementwise
2424
from cvxpy.constraints.constraint import Constraint
25+
from cvxpy.expressions.variable import Variable
2526

2627

2728
class kl_div(Elementwise):
@@ -90,6 +91,14 @@ def _grad(self, values) -> List[Optional[csc_array]]:
9091
rows, cols)]
9192
return grad_list
9293

94+
def _verify_hess_vec_args(self):
95+
return isinstance(self.args[0], Variable)
96+
97+
def _hess_vec(self, vec):
98+
""" See the docstring of the hess_vec method of the atom class. """
99+
x = self.args[0]
100+
return {(x, x): np.diag(-vec / x.value)}
101+
93102
def _domain(self) -> List[Constraint]:
94103
"""Returns constraints describing the domain of the node.
95104
"""

cvxpy/atoms/elementwise/maximum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,10 @@ def _grad(self, values) -> List[Any]:
109109
grad_list += [maximum.elemwise_grad_to_diag(grad_vals,
110110
rows, cols)]
111111
return grad_list
112+
113+
def _verify_hess_vec_args(self):
114+
return True
115+
116+
def _hess_vec(self, vec):
117+
"""See the docstring of the hess_vec method of the atom class."""
118+
raise NotImplementedError("Second derivative of maximum is not implemented yet.")

cvxpy/atoms/elementwise/minimum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,10 @@ def _grad(self, values) -> List[Any]:
101101
grad_list += [minimum.elemwise_grad_to_diag(grad_vals,
102102
rows, cols)]
103103
return grad_list
104+
105+
def _verify_hess_vec_args(self):
106+
return True
107+
108+
def _hess_vec(self, vec):
109+
""" See the docstring of the hess_vec method of the atom class. """
110+
raise NotImplementedError("Second derivative of minimum is not implemented yet.")

cvxpy/atoms/elementwise/trig.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,18 @@ def _grad(self, values) -> List[Constraint]:
8383
grad_vals = np.cos(values[0])
8484
return [sin.elemwise_grad_to_diag(grad_vals, rows, cols)]
8585

86+
def _verify_hess_vec_args(self):
87+
return True
88+
89+
def _hess_vec(self, vec):
90+
"""
91+
Computes the Hessian-vector product dictionary
92+
for the sin atom. We assume that the argument will be a variable.
93+
"""
94+
hess_dict = {}
95+
var = self.args[0]
96+
hess_dict[(var, var)] = np.diag(-np.sin(var.value) * vec)
97+
return hess_dict
8698

8799
class cos(Elementwise):
88100
"""Elementwise :math:`\\cos x`.
@@ -145,6 +157,19 @@ def _grad(self, values) -> List[Constraint]:
145157
cols = self.size
146158
grad_vals = -np.sin(values[0])
147159
return [cos.elemwise_grad_to_diag(grad_vals, rows, cols)]
160+
161+
def _verify_hess_vec_args(self):
162+
return True
163+
164+
def _hess_vec(self, vec):
165+
"""
166+
Computes the Hessian-vector product dictionary
167+
for the cos atom. We assume that the argument will be a variable.
168+
"""
169+
hess_dict = {}
170+
var = self.args[0]
171+
hess_dict[(var, var)] = np.diag(-np.cos(var.value) * vec)
172+
return hess_dict
148173

149174

150175
class tan(Elementwise):
@@ -208,3 +233,16 @@ def _grad(self, values) -> List[Constraint]:
208233
cols = self.size
209234
grad_vals = 1/np.cos(values[0])**2
210235
return [tan.elemwise_grad_to_diag(grad_vals, rows, cols)]
236+
237+
def _verify_hess_vec_args(self):
238+
return True
239+
240+
def _hess_vec(self, vec):
241+
"""
242+
Computes the Hessian-vector product dictionary
243+
for the tan atom. We assume that the argument will be a variable.
244+
"""
245+
hess_dict = {}
246+
var = self.args[0]
247+
hess_dict[(var, var)] = np.diag(2*np.tan(var.value)/np.cos(var.value)**2 * vec)
248+
return hess_dict

cvxpy/atoms/pnorm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,10 @@ def _column_grad(self, value):
270270
nominator = np.power(value, exp)
271271
frac = np.divide(nominator, denominator)
272272
return np.reshape(frac, (frac.size, 1))
273+
274+
def _verify_hess_vec_args(self):
275+
return True
276+
277+
def _hess_vec(self, vec):
278+
""" See the docstring of the hess_vec method of the atom class. """
279+
raise NotImplementedError("Second derivative of p-norm is not implemented yet.")

cvxpy/atoms/quad_form.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,21 @@ def _grad(self, values):
126126
D = (P + np.conj(P.T)) @ x
127127
return [sp.csc_array([D.ravel(order="F")]).T]
128128

129-
def _hess(self, values):
129+
def _verify_hess_vec_args(self):
130+
return True
131+
132+
def _hess_vec(self, vec):
130133
"""
131-
The hessian of a quadratic form x.T @ Q @ x
132-
with respect to x, is the constant matrix Q.
134+
Computes the Hessian-vector product dictionary
135+
for a quadratic form. We assume that the quad-form will be
136+
canonicalized to w.T @ Q @ w, where w is a single variable
137+
and Q is a constant matrix.
133138
"""
134-
var = self.variables()[0]
135-
return {(var, var): 2 * np.array(values[1])}
139+
hess_dict = {}
140+
var = self.args[0]
141+
Q = self.args[1]
142+
hess_dict[(var, var)] = vec * 2 * Q.value
143+
return hess_dict
136144

137145
def shape_from_args(self) -> Tuple[int, ...]:
138146
return tuple()
@@ -174,6 +182,22 @@ def sign_from_args(self) -> Tuple[bool, bool]:
174182
def is_quadratic(self) -> bool:
175183
return True
176184

185+
def _verify_hess_vec_args(self):
186+
return True
187+
188+
def _hess_vec(self, vec):
189+
"""
190+
Computes the Hessian-vector product dictionary
191+
for a quadratic form. We assume that the quad-form will be
192+
canonicalized to w.T @ Q @ w, where w is a single variable
193+
and Q is a constant matrix.
194+
"""
195+
hess_dict = {}
196+
var = self.args[0]
197+
Q = self.args[1]
198+
hess_dict[(var, var)] = vec * 2 * Q.value
199+
return hess_dict
200+
177201

178202
def decomp_quad(P, cond=None, rcond=None, lower=True, check_finite: bool = True):
179203
"""

cvxpy/tests/NLP_tests/remove_later.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)