Skip to content

Commit 59c605b

Browse files
authored
Merge pull request #76 from marcusvaltonen/user/marcus/separable-singular-value-penalties
Implement generic singular value penalty
2 parents ffd376d + b4b6c3b commit 59c605b

File tree

4 files changed

+90
-5
lines changed

4 files changed

+90
-5
lines changed

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Operators
8080
ETP
8181
Geman
8282
QuadraticEnvelopeCard
83+
SingularValuePenalty
8384

8485

8586
Other operators
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import numpy as np
2+
3+
from pyproximal.ProxOperator import _check_tau
4+
from pyproximal import ProxOperator
5+
6+
7+
class SingularValuePenalty(ProxOperator):
8+
r"""Proximal operator of a penalty acting on the singular values.
9+
10+
Generic regularizer :math:`\mathcal{R}_f` acting on the singular values of a matrix,
11+
12+
.. math::
13+
14+
\mathcal{R}_f(\mathbf{X}) = f(\boldsymbol\lambda)
15+
16+
where :math:`\mathbf{X}` is a matrix of size :math:`M \times N` and
17+
:math:`\boldsymbol\lambda` is the corresponding singular value vector.
18+
19+
Parameters
20+
----------
21+
dim : :obj:`tuple`
22+
Size of matrix :math:`\mathbf{X}`.
23+
penalty : :class:`pyproximal.ProxOperator`
24+
Function acting on the singular values.
25+
26+
Notes
27+
-----
28+
The pyproximal implementation allows ``penalty`` to be any
29+
:class:`pyproximal.ProxOperator` acting on the singular values; however, not all
30+
penalties will result in a mathematically accurate proximal operator defined this
31+
way. Given a penalty :math:`f`, the proximal operator is assumed to be
32+
33+
.. math::
34+
35+
\prox_{\tau \mathcal{R}_f}(\mathbf{X}) =
36+
\mathbf{U} \diag\left( \prox_{\tau f}(\boldsymbol\lambda)\right) \mathbf{V}^H
37+
38+
where :math:`\mathbf{X} = \mathbf{U}\diag(\boldsymbol\lambda)\mathbf{V}^H`, is an
39+
SVD of :math:`\mathbf{X}`. It is the user's responsibility to check that this is
40+
true for their particular choice of ``penalty``.
41+
"""
42+
43+
def __init__(self, dim, penalty):
44+
super().__init__(None, False)
45+
self.dim = dim
46+
self.penalty = penalty
47+
48+
def __call__(self, x):
49+
X = x.reshape(self.dim)
50+
eigs = np.linalg.eigvalsh(X.T @ X)
51+
eigs[eigs < 0] = 0 # ensure all eigenvalues at positive
52+
return np.sum(self.penalty(np.sqrt(eigs)))
53+
54+
@_check_tau
55+
def prox(self, x, tau):
56+
X = x.reshape(self.dim)
57+
U, S, Vh = np.linalg.svd(X, full_matrices=False)
58+
X = np.dot(U * self.penalty.prox(S, tau), Vh)
59+
return X.ravel()

pyproximal/proximal/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ETP Exponential-type penalty
3030
Geman Geman penalty
3131
QuadraticEnvelopeCard The quadratic envelope of the cardinality function
32+
SingularValuePenalty Generic singular value penalty
3233
3334
"""
3435

@@ -53,9 +54,10 @@
5354
from .ETP import *
5455
from .Geman import *
5556
from .QuadraticEnvelope import *
57+
from .SingularValuePenalty import *
5658

5759
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
5860
'Euclidean', 'EuclideanBall', 'L0Ball', 'L1', 'L1Ball', 'L2',
5961
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'Nuclear',
6062
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
61-
'Log', 'ETP', 'Geman', 'QuadraticEnvelopeCard']
63+
'Log', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty']

pytests/test_proximal.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33
import numpy as np
44
from numpy.testing import assert_array_equal, assert_array_almost_equal
55
from pylops import MatrixMult, Identity
6+
7+
import pyproximal
68
from pyproximal.utils import moreau
7-
from pyproximal.proximal import Quadratic, Nonlinear, \
8-
L1, L2, Orthogonal, VStack
9+
from pyproximal.proximal import L1, L2, Nonlinear, Orthogonal, Quadratic, \
10+
SingularValuePenalty, VStack
911

1012
par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
1113
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
1214

13-
np.random.seed(10)
14-
1515

1616
@pytest.mark.parametrize("par", [(par1), (par2)])
1717
def test_Quadratic(par):
1818
"""Quadratic functional and proximal/dual proximal
1919
"""
20+
np.random.seed(10)
2021
A = np.random.normal(0, 1, (par['nx'], par['nx']))
2122
A = A.T @ A
2223
quad = Quadratic(Op=MatrixMult(A), b=np.ones(par['nx']), niter=500)
@@ -31,6 +32,7 @@ def test_Quadratic(par):
3132
def test_DotProduct(par):
3233
"""Dot product functional and proximal/dual proximal
3334
"""
35+
np.random.seed(10)
3436
quad = Quadratic(b=np.ones(par['nx']))
3537

3638
# prox / dualprox
@@ -43,6 +45,7 @@ def test_DotProduct(par):
4345
def test_Constant(par):
4446
"""Constant functional and proximal/dual proximal
4547
"""
48+
np.random.seed(10)
4649
quad = Quadratic(c=5.)
4750

4851
# prox / dualprox
@@ -55,6 +58,7 @@ def test_Constant(par):
5558
def test_SemiOrthogonal(par):
5659
"""L1 functional with Semi-Orthogonal operator and proximal/dual proximal
5760
"""
61+
np.random.seed(10)
5862
l1 = L1()
5963
orth = Orthogonal(l1, 2*Identity(par['nx']), b=np.arange(par['nx']),
6064
partial=True, alpha=4.)
@@ -69,6 +73,7 @@ def test_SemiOrthogonal(par):
6973
def test_Orthogonal(par):
7074
"""L1 functional with Orthogonal operator and proximal/dual proximal
7175
"""
76+
np.random.seed(10)
7277
l1 = L1()
7378
orth = Orthogonal(l1, Identity(par['nx']), b=np.arange(par['nx']))
7479

@@ -82,6 +87,7 @@ def test_Orthogonal(par):
8287
def test_VStack(par):
8388
"""L2 functional with VStack operator of multiple L1s
8489
"""
90+
np.random.seed(10)
8591
nxs = [par['nx'] // 4] * 4
8692
nxs[-1] = par['nx'] - np.sum(nxs[:-1])
8793
l2 = L2()
@@ -106,6 +112,7 @@ def test_Nonlinear():
106112
"""Nonlinear proximal operator. Since this is a template class simply check
107113
that errors are raised when not used properly
108114
"""
115+
np.random.seed(10)
109116
Nop = Nonlinear(np.ones(10))
110117
with pytest.raises(NotImplementedError):
111118
Nop.fun(np.ones(10))
@@ -115,4 +122,20 @@ def test_Nonlinear():
115122
Nop.optimize()
116123

117124

125+
@pytest.mark.parametrize("par", [(par1), (par2)])
126+
def test_SingularValuePenalty(par):
127+
"""Test SingularValuePenalty
128+
"""
129+
np.random.seed(10)
130+
f_mu = pyproximal.QuadraticEnvelopeCard(mu=par['sigma'])
131+
penalty = SingularValuePenalty((par['nx'], 2 * par['nx']), f_mu)
132+
133+
# norm, cross-check with svd (use tolerance as two methods don't provide
134+
# the exact same eigenvalues)
135+
X = np.random.uniform(0., 0.1, (par['nx'], 2 * par['nx'])).astype(par['dtype'])
136+
_, S, _ = np.linalg.svd(X)
137+
assert (penalty(X.ravel()) - f_mu(S)) < 1e-3
118138

139+
# prox / dualprox
140+
tau = 0.75
141+
assert moreau(penalty, X.ravel(), tau)

0 commit comments

Comments
 (0)