Skip to content

Commit ffd376d

Browse files
authored
Merge pull request #72 from marcusvaltonen/user/marcus/weighted-nuclear-norm
Add support for weighted nuclear norm
2 parents 974547e + a8b9e45 commit ffd376d

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

pyproximal/proximal/Nuclear.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,52 @@
99
class Nuclear(ProxOperator):
1010
r"""Nuclear norm proximal operator.
1111
12-
Proximal operator of the Nuclear norm defined as
13-
:math:`\sigma\|\mathbf{X}\|_* = \sigma \sum_i \sigma_i` where
14-
:math:`\mathbf{X}` is a matrix of size :math:`M \times N` and
15-
:math:`\sigma_i=1, \ldots, \min(M, N)` are its singular values.
12+
The nuclear norm is defined as
13+
:math:`\sigma\|\mathbf{X}\|_* = \sigma \sum_i \lambda_i` where :math:`\mathbf{X}`
14+
is a matrix of size :math:`M \times N` and :math:`\lambda_i` is the *i*:th
15+
singular value of :math:`\mathbf{X}`, where :math:`i=1,\ldots, \min(M, N)`.
16+
17+
The *weighted* nuclear norm, with the positive weight vector :math:`\boldsymbol\sigma`, is
18+
defined as
19+
20+
.. math::
21+
22+
\|\mathbf{X}|\|_{{\boldsymbol\sigma},*} = \sum_i \sigma_i\lambda_i(\mathbf{X}) .
1623
1724
Parameters
1825
----------
1926
dim : :obj:`tuple`
2027
Size of matrix :math:`\mathbf{X}`
21-
sigma : :obj:`float`, optional
22-
Multiplicative coefficient of nuclear norm
28+
sigma : :obj:`float` or :obj:`numpy.ndarray`, optional
29+
Multiplicative coefficient of the nuclear norm penalty. If ``sigma`` is a float
30+
the same penalty is applied for all singular values. If instead ``sigma`` is an
31+
array the weight ``sigma[i]`` will be applied to the *i*:th singular value.
32+
This is often referred to as the *weighted nuclear norm*.
2333
2434
Notes
2535
-----
26-
The Nuclear norm proximal operator is defined as:
36+
The nuclear norm proximal operator is:
2737
2838
.. math::
2939
3040
\prox_{\tau \sigma \|\cdot\|_*}(\mathbf{X}) =
31-
\mathbf{U} \diag\left( \prox_{\tau \sigma \|\cdot\|_1}(\boldsymbol\lambda)\right) \mathbf{V}^H
41+
\mathbf{U} \diag \{ \prox_{\tau \sigma \|\cdot\|_1}(\boldsymbol\lambda) \} \mathbf{V}^H
3242
33-
where :math:`\mathbf{X} = \mathbf{U}\diag(\boldsymbol\lambda)\mathbf{V}^H`, is an SVD of :math:`X`.
43+
where :math:`\mathbf{U}`, :math:`\boldsymbol\lambda`, and
44+
:math:`\mathbf{V}` define the SVD of :math:`X`.
45+
46+
The weighted nuclear norm is convex if the sequence :math:`\{\sigma_i\}_i` is
47+
non-ascending, but is in general non-convex; however, when the weights are
48+
non-descending it can be shown that applying the soft-thresholding operator on the
49+
singular values still yields a fixed point (w. r. t. a specific algorithm), see
50+
[1]_ for details.
51+
52+
.. [1] Gu et al. "Weighted Nuclear Norm Minimization with Application to Image
53+
Denoising", In the IEEE Conference on Computer Vision and Pattern Recognition,
54+
2862-2869, 2014.
3455
3556
"""
57+
3658
def __init__(self, dim, sigma=1.):
3759
super().__init__(None, False)
3860
self.dim = dim
@@ -42,14 +64,14 @@ def __call__(self, x):
4264
X = x.reshape(self.dim)
4365
eigs = np.linalg.eigvalsh(X.T @ X)
4466
eigs[eigs < 0] = 0 # ensure all eigenvalues at positive
45-
nucl = np.sum(np.sqrt(eigs))
46-
return self.sigma * nucl
67+
return np.sum(np.flip(self.sigma) * np.sqrt(eigs))
4768

4869
@_check_tau
4970
def prox(self, x, tau):
5071
X = x.reshape(self.dim)
5172
U, S, Vh = np.linalg.svd(X, full_matrices=False)
52-
Sth = _softthreshold(S, tau * self.sigma)
73+
sigma = self.sigma if np.isscalar(self.sigma) else self.sigma[:S.size]
74+
Sth = _softthreshold(S, tau * sigma)
5375
X = np.dot(U * Sth, Vh)
5476
return X.ravel()
5577

pytests/test_norms.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,22 @@ def test_Nuclear(par):
194194

195195
# prox / dualprox
196196
tau = 2.
197-
assert moreau(nucl, X.ravel(), tau)
197+
assert moreau(nucl, X.ravel(), tau)
198+
199+
200+
@pytest.mark.parametrize("par", [(par1), (par2)])
201+
def test_Weighted_Nuclear(par):
202+
"""Weighted nuclear norm and proximal/dual proximal
203+
"""
204+
weights = par['sigma'] * np.linspace(0.1, 5, 2 * par['nx'])
205+
nucl = Nuclear((par['nx'], 2 * par['nx']), sigma=weights)
206+
207+
# norm, cross-check with svd (use tolerance as two methods don't provide
208+
# the exact same singular values)
209+
X = np.random.uniform(0., 0.1, (par['nx'], 2 * par['nx'])).astype(par['dtype'])
210+
S = np.linalg.svd(X, compute_uv=False)
211+
assert (nucl(X.ravel()) - np.sum(weights[:S.size] * S)) < 1e-3
212+
213+
# prox / dualprox
214+
tau = 2.
215+
assert moreau(nucl, X.ravel(), tau)

0 commit comments

Comments
 (0)