Skip to content

Commit 46d7c05

Browse files
authored
fix duplicate index in hessian of indexing (#112)
1 parent 88044a9 commit 46d7c05

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

cvxpy/atoms/affine/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _hess_vec(self, vec):
121121
"""See the docstring of the hess_vec method of the atom class. """
122122
idx = self._orig_key
123123
e = np.zeros(self.args[0].size)
124-
e[idx] = vec
124+
np.add.at(e, np.atleast_1d(idx), vec)
125125
return self.args[0].hess_vec(e)
126126

127127
def _jacobian(self):
@@ -255,7 +255,7 @@ def _hess_vec(self, vec):
255255
""" See the docstring of the hess_vec method of the atom class. """
256256
idx = np.reshape(self._select_mat, self._select_mat.size, order='F')
257257
e = np.zeros(self.args[0].size)
258-
e[idx] = vec
258+
np.add.at(e, np.atleast_1d(idx), vec)
259259
return self.args[0].hess_vec(e)
260260

261261
def _verify_jacobian_args(self):

cvxpy/tests/NLP_tests/hess_tests/test_hess_idx.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
import numpy as np
2-
import pytest
32

43
import cvxpy as cp
54

65

76
class TestHessIndex():
87

8+
def test_scalar_idx(self):
9+
x = cp.Variable((1,), name='x')
10+
x.value = np.array([3.0])
11+
vec = np.array(4)
12+
log2 = cp.log(x)[0]
13+
result_dict = log2.hess_vec(vec)
14+
correct_matrix = 4 * (-np.diag(np.array([1/9])))
15+
computed_hess = np.zeros((1, 1))
16+
rows, cols, vals = result_dict[(x, x)]
17+
computed_hess[rows, cols] = vals
18+
assert(np.allclose(computed_hess, correct_matrix))
19+
920
def test_single_idx(self):
1021
n = 3
11-
x = cp.Variable((n, ), name='x')
22+
x = cp.Variable((n,), name='x')
1223
x.value = np.array([1.0, 2.0, 3.0])
1324
vec = np.array([4])
1425
log2 = cp.log(x)[2]
@@ -21,7 +32,7 @@ def test_single_idx(self):
2132

2233
def test_slice_two_idx(self):
2334
n = 3
24-
x = cp.Variable((n, ), name='x')
35+
x = cp.Variable((n,), name='x')
2536
x.value = np.array([1.0, 2.0, 3.0])
2637
vec = np.array([2, 4])
2738
idxs = np.array([1, 2])
@@ -36,7 +47,7 @@ def test_slice_two_idx(self):
3647

3748
def test_slice_two_other_idx(self):
3849
n = 3
39-
x = cp.Variable((n, ), name='x')
50+
x = cp.Variable((n,), name='x')
4051
x.value = np.array([1.5, 2.0, 3.0])
4152
vec = np.array([2, 4])
4253
idxs = np.array([0, 2])
@@ -67,10 +78,10 @@ def test_special_index_matrix(self):
6778
computed_hess[rows, cols] = vals
6879
assert(np.allclose(computed_hess, correct_matrix))
6980

70-
@pytest.mark.skip(reason="TODO fix this test for duplicate indices")
7181
def test_special_index_duplicate_matrix(self):
7282
"""
73-
TODO fix this test
83+
This test was failing because hess_vec didn't properly handle
84+
duplicate indices.
7485
"""
7586
x = cp.Variable((2, 2), name='x')
7687
x.value = np.array([[1.0, 2.0], [3.0, 4.0]])

0 commit comments

Comments
 (0)