33import numpy as np
44from numpy .testing import assert_array_equal , assert_array_almost_equal
55from pylops import MatrixMult , Identity
6+
7+ import pyproximal
68from pyproximal .utils import moreau
7- from pyproximal .proximal import Quadratic , Nonlinear , \
8- L1 , L2 , Orthogonal , VStack
9+ from pyproximal .proximal import L1 , L2 , Nonlinear , Orthogonal , Quadratic , \
10+ SingularValuePenalty , VStack
911
1012par1 = {'nx' : 10 , 'sigma' : 1. , 'dtype' : 'float32' } # even float32
1113par2 = {'nx' : 11 , 'sigma' : 2. , 'dtype' : 'float64' } # odd float64
1214
13- np .random .seed (10 )
14-
1515
1616@pytest .mark .parametrize ("par" , [(par1 ), (par2 )])
1717def test_Quadratic (par ):
1818 """Quadratic functional and proximal/dual proximal
1919 """
20+ np .random .seed (10 )
2021 A = np .random .normal (0 , 1 , (par ['nx' ], par ['nx' ]))
2122 A = A .T @ A
2223 quad = Quadratic (Op = MatrixMult (A ), b = np .ones (par ['nx' ]), niter = 500 )
@@ -31,6 +32,7 @@ def test_Quadratic(par):
3132def test_DotProduct (par ):
3233 """Dot product functional and proximal/dual proximal
3334 """
35+ np .random .seed (10 )
3436 quad = Quadratic (b = np .ones (par ['nx' ]))
3537
3638 # prox / dualprox
@@ -43,6 +45,7 @@ def test_DotProduct(par):
4345def test_Constant (par ):
4446 """Constant functional and proximal/dual proximal
4547 """
48+ np .random .seed (10 )
4649 quad = Quadratic (c = 5. )
4750
4851 # prox / dualprox
@@ -55,6 +58,7 @@ def test_Constant(par):
5558def test_SemiOrthogonal (par ):
5659 """L1 functional with Semi-Orthogonal operator and proximal/dual proximal
5760 """
61+ np .random .seed (10 )
5862 l1 = L1 ()
5963 orth = Orthogonal (l1 , 2 * Identity (par ['nx' ]), b = np .arange (par ['nx' ]),
6064 partial = True , alpha = 4. )
@@ -69,6 +73,7 @@ def test_SemiOrthogonal(par):
6973def test_Orthogonal (par ):
7074 """L1 functional with Orthogonal operator and proximal/dual proximal
7175 """
76+ np .random .seed (10 )
7277 l1 = L1 ()
7378 orth = Orthogonal (l1 , Identity (par ['nx' ]), b = np .arange (par ['nx' ]))
7479
@@ -82,6 +87,7 @@ def test_Orthogonal(par):
8287def test_VStack (par ):
8388 """L2 functional with VStack operator of multiple L1s
8489 """
90+ np .random .seed (10 )
8591 nxs = [par ['nx' ] // 4 ] * 4
8692 nxs [- 1 ] = par ['nx' ] - np .sum (nxs [:- 1 ])
8793 l2 = L2 ()
@@ -106,6 +112,7 @@ def test_Nonlinear():
106112 """Nonlinear proximal operator. Since this is a template class simply check
107113 that errors are raised when not used properly
108114 """
115+ np .random .seed (10 )
109116 Nop = Nonlinear (np .ones (10 ))
110117 with pytest .raises (NotImplementedError ):
111118 Nop .fun (np .ones (10 ))
@@ -115,4 +122,20 @@ def test_Nonlinear():
115122 Nop .optimize ()
116123
117124
125+ @pytest .mark .parametrize ("par" , [(par1 ), (par2 )])
126+ def test_SingularValuePenalty (par ):
127+ """Test SingularValuePenalty
128+ """
129+ np .random .seed (10 )
130+ f_mu = pyproximal .QuadraticEnvelopeCard (mu = par ['sigma' ])
131+ penalty = SingularValuePenalty ((par ['nx' ], 2 * par ['nx' ]), f_mu )
132+
133+ # norm, cross-check with svd (use tolerance as two methods don't provide
134+ # the exact same eigenvalues)
135+ X = np .random .uniform (0. , 0.1 , (par ['nx' ], 2 * par ['nx' ])).astype (par ['dtype' ])
136+ _ , S , _ = np .linalg .svd (X )
137+ assert (penalty (X .ravel ()) - f_mu (S )) < 1e-3
118138
139+ # prox / dualprox
140+ tau = 0.75
141+ assert moreau (penalty , X .ravel (), tau )
0 commit comments