11import numpy as np
2- import pytest
32
43import cvxpy as cp
54
65
76class TestHessIndex ():
87
8+ def test_scalar_idx (self ):
9+ x = cp .Variable ((1 ,), name = 'x' )
10+ x .value = np .array ([3.0 ])
11+ vec = np .array (4 )
12+ log2 = cp .log (x )[0 ]
13+ result_dict = log2 .hess_vec (vec )
14+ correct_matrix = 4 * (- np .diag (np .array ([1 / 9 ])))
15+ computed_hess = np .zeros ((1 , 1 ))
16+ rows , cols , vals = result_dict [(x , x )]
17+ computed_hess [rows , cols ] = vals
18+ assert (np .allclose (computed_hess , correct_matrix ))
19+
920 def test_single_idx (self ):
1021 n = 3
11- x = cp .Variable ((n , ), name = 'x' )
22+ x = cp .Variable ((n ,), name = 'x' )
1223 x .value = np .array ([1.0 , 2.0 , 3.0 ])
1324 vec = np .array ([4 ])
1425 log2 = cp .log (x )[2 ]
@@ -21,7 +32,7 @@ def test_single_idx(self):
2132
2233 def test_slice_two_idx (self ):
2334 n = 3
24- x = cp .Variable ((n , ), name = 'x' )
35+ x = cp .Variable ((n ,), name = 'x' )
2536 x .value = np .array ([1.0 , 2.0 , 3.0 ])
2637 vec = np .array ([2 , 4 ])
2738 idxs = np .array ([1 , 2 ])
@@ -36,7 +47,7 @@ def test_slice_two_idx(self):
3647
3748 def test_slice_two_other_idx (self ):
3849 n = 3
39- x = cp .Variable ((n , ), name = 'x' )
50+ x = cp .Variable ((n ,), name = 'x' )
4051 x .value = np .array ([1.5 , 2.0 , 3.0 ])
4152 vec = np .array ([2 , 4 ])
4253 idxs = np .array ([0 , 2 ])
@@ -67,10 +78,10 @@ def test_special_index_matrix(self):
6778 computed_hess [rows , cols ] = vals
6879 assert (np .allclose (computed_hess , correct_matrix ))
6980
70- @pytest .mark .skip (reason = "TODO fix this test for duplicate indices" )
7181 def test_special_index_duplicate_matrix (self ):
7282 """
73- TODO fix this test
83+ This test was failing because hess_vec didn't properly handle
84+ duplicate indices.
7485 """
7586 x = cp .Variable ((2 , 2 ), name = 'x' )
7687 x .value = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
0 commit comments