
Commit 6003bc7

committed
feature: added bilinear update to ProximalGradient
1 parent 9809201 commit 6003bc7

File tree

3 files changed: +78 -9 lines changed

pyproximal/optimization/primal.py

Lines changed: 6 additions & 1 deletion
@@ -6,6 +6,7 @@
 from pylops.optimization.leastsquares import regularized_inversion
 from pylops.utils.backend import to_numpy
 from pyproximal.proximal import L2
+from pyproximal.utils.bilinear import BilinearOperator


 def _backtracking(x, tau, proxf, proxg, epsg, beta=0.5, niterback=10):
@@ -239,7 +240,11 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
         else:
             x, tau = _backtracking(y, tau, proxf, proxg, epsg,
                                    beta=beta, niterback=niterback)
-
+
+        # update internal parameters for bilinear operator
+        if isinstance(proxf, BilinearOperator):
+            proxf.updatexy(x)
+
         # update y
         if acceleration == 'vandenberghe':
             omega = iiter / (iiter + 3)
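
This hook lets ProximalGradient keep a BilinearOperator passed as proxf in sync with the iterate after every proximal step. A minimal sketch of the pattern (a fixed-step loop for illustration only; the real solver also handles backtracking and acceleration, and scales the proximal step differently):

import numpy as np
from pyproximal.utils.bilinear import BilinearOperator

def proximal_gradient_sketch(proxf, proxg, x0, tau, niter=100):
    # illustrative fixed-step proximal-gradient loop
    x = y = x0.copy()
    for _ in range(niter):
        # gradient step on proxf followed by proximal step on proxg
        x = proxg.prox(y - tau * proxf.grad(y), tau)
        # new in this commit: refresh the bilinear operator's internal
        # x/y variables so later gradient calls use the latest iterate
        if isinstance(proxf, BilinearOperator):
            proxf.updatexy(x)
        y = x
    return x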

pyproximal/utils/bilinear.py

Lines changed: 22 additions & 8 deletions
@@ -16,6 +16,11 @@ class BilinearOperator():
     - ``lx``: Lipschitz constant of :math:`\nabla_x H`
     - ``ly``: Lipschitz constant of :math:`\nabla_y H`

+    Two additional methods (``updatex`` and ``updatey``) are provided to
+    update the :math:`\mathbf{x}` and :math:`\mathbf{y}` internal
+    variables. It is the user's responsibility to choose when to invoke
+    such methods (i.e., when to update the internal variables).
+
     Notes
     -----
     A bilinear operator is defined as a differentiable nonlinear function
@@ -45,15 +50,17 @@ def ly(self, y):
         pass

     def updatex(self, x):
-        """Update x variable (used when evaluating the gradient over y
+        """Update x variable (to be used to update the internal variable x)
         """
         self.x = x

     def updatey(self, y):
-        """Update y variable (used when evaluating the gradient over y
+        """Update y variable (to be used to update the internal variable y)
         """
         self.y = y

+    def updatexy(self, xy):
+        pass

 class LowRankFactorizedMatrix(BilinearOperator):
     r"""Low-Rank Factorized Matrix operator.
@@ -83,16 +90,21 @@ class LowRankFactorizedMatrix(BilinearOperator):

     .. math::

-        \nabla_x H = \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
+        \nabla_x H(\mathbf{x}; \mathbf{y}) =
+        \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
         - \mathbf{d})\mathbf{Y}^H

     and gradient with respect to y equal to:

     .. math::

-        \nabla_y H = \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
+        \nabla_y H(\mathbf{y}; \mathbf{x}) =
+        \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
         (\mathbf{X}\mathbf{Y}) - \mathbf{d})

+    Note that in both cases the gradient is evaluated using the currently
+    stored value of the second variable (the one after the semicolon).
+
     """
     def __init__(self, X, Y, d, Op=None):
         self.n, self.k = X.shape
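
Assuming the quadratic misfit H = 1/2 ||Op(XY) - d||^2 implied by these gradients, and taking Op as the identity for simplicity, both expressions reduce to plain matrix products. A short NumPy illustration of the two formulas (shapes chosen arbitrarily):

import numpy as np

n, m, k = 6, 5, 3
rng = np.random.default_rng(0)
X = rng.standard_normal((n, k))
Y = rng.standard_normal((k, m))
d = rng.standard_normal((n, m))

R = X @ Y - d                # residual Op(XY) - d with Op = identity
gradx = R @ np.conj(Y.T)     # \nabla_x H(x; y), shape (n, k)
grady = np.conj(X.T) @ R     # \nabla_y H(y; x), shape (k, m)

The conjugate transposes match the np.conj fixes in gradx/grady below, which make the gradients correct for complex-valued factors as well.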
@@ -160,7 +172,7 @@ def gradx(self, x):
             r = (self.Op.H @ r).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -r @ self.y.reshape(self.k, self.m).T
+        g = -r @ np.conj(self.y.reshape(self.k, self.m).T)
         return g.ravel()

     def grady(self, y):
@@ -169,13 +181,15 @@ def grady(self, y):
             r = (self.Op.H @ r.ravel()).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -self.x.reshape(self.n, self.k).T @ r
+        g = -np.conj(self.x.reshape(self.n, self.k).T) @ r
         return g.ravel()

     def grad(self, x):
-        self.updatex(x[:self.n * self.k])
-        self.updatey(x[self.n * self.k:])
         gx = self.gradx(x[:self.n * self.k])
         gy = self.grady(x[self.n * self.k:])
         g = np.hstack([gx, gy])
         return g
+
+    def updatexy(self, x):
+        self.updatex(x[:self.n * self.k])
+        self.updatey(x[self.n * self.k:])
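
With grad no longer calling updatex/updatey internally, the stored factors only change when updatexy is invoked, e.g. by the new ProximalGradient hook above; gradients can therefore be evaluated at trial points (say, during backtracking) without committing them to the operator's state. A hedged usage sketch on the stacked vector [vec(X), vec(Y)] (shapes follow the tests below):

import numpy as np
from pyproximal.utils.bilinear import LowRankFactorizedMatrix

n, m, k = 21, 11, 5
X = np.random.normal(0., 1., (n, k))
Y = np.random.normal(0., 1., (k, m))
d = X @ Y                                # data, here the exact product

LOp = LowRankFactorizedMatrix(X, Y, d)

xy = np.hstack([X.ravel(), Y.ravel()])   # stacked unknown
g = LOp.grad(xy)                         # stacked gradients, state untouched
LOp.updatexy(xy)                         # explicit refresh of x and y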

pytests/test_bilinear.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from pylops import Diagonal
+
+from pyproximal.utils.bilinear import LowRankFactorizedMatrix
+
+
+par1 = {'n': 21, 'm': 11, 'k': 5, 'imag': 0,
+        'dtype': 'float32'}  # real
+par2 = {'n': 21, 'm': 11, 'k': 5, 'imag': 1j,
+        'dtype': 'complex64'}  # complex
+
+np.random.seed(10)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_lrfactorized(par):
+    """Check equivalence of matvec operations in LowRankFactorizedMatrix
+    """
+    U = (np.random.normal(0., 1., (par['n'], par['k'])) +
+         par['imag'] * np.random.normal(0., 1., (par['n'], par['k']))).astype(par['dtype'])
+    V = (np.random.normal(0., 1., (par['m'], par['k'])) +
+         par['imag'] * np.random.normal(0., 1., (par['m'], par['k']))).astype(par['dtype'])
+
+    X = U @ V.T
+    LOp = LowRankFactorizedMatrix(U, V.T, X)
+
+    assert_array_equal(X.ravel(), LOp._matvecx(U.ravel()))
+    assert_array_equal(X.ravel(), LOp._matvecy(V.T.ravel()))
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_lrfactorizedoperator(par):
+    """Check equivalence of matvec operations in LowRankFactorizedMatrix with operator Op
+    """
+    U = (np.random.normal(0., 1., (par['n'], par['k'])) +
+         par['imag'] * np.random.normal(0., 1., (par['n'], par['k']))).astype(par['dtype'])
+    V = (np.random.normal(0., 1., (par['m'], par['k'])) +
+         par['imag'] * np.random.normal(0., 1., (par['m'], par['k']))).astype(par['dtype'])
+    Op = Diagonal(np.arange(par['n'] * par['m']) + 1.)
+
+    X = U @ V.T
+    y = Op @ X.ravel()
+
+    LOp = LowRankFactorizedMatrix(U, V.T, y, Op)
+
+    assert_array_equal(y, LOp._matvecx(U.ravel()))
+    assert_array_equal(y, LOp._matvecy(V.T.ravel()))
