Commit c9a7e06

Merge pull request #124 from mrava87/dev
feature: added bilinear update to ProximalGradient
2 parents b18af36 + 39d5a6d commit c9a7e06

File tree: 3 files changed, +79 −9 lines

pyproximal/optimization/primal.py

Lines changed: 6 additions & 1 deletion
@@ -6,6 +6,7 @@
 from pylops.optimization.leastsquares import regularized_inversion
 from pylops.utils.backend import to_numpy
 from pyproximal.proximal import L2
+from pyproximal.utils.bilinear import BilinearOperator


 def _backtracking(x, tau, proxf, proxg, epsg, beta=0.5, niterback=10):
@@ -239,7 +240,11 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
         else:
             x, tau = _backtracking(y, tau, proxf, proxg, epsg,
                                    beta=beta, niterback=niterback)
-
+
+        # update internal parameters for bilinear operator
+        if isinstance(proxf, BilinearOperator):
+            proxf.updatexy(x)
+
         # update y
         if acceleration == 'vandenberghe':
            omega = iiter / (iiter + 3)
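
With this change, a BilinearOperator (such as LowRankFactorizedMatrix) passed as proxf has its internal x and y variables refreshed from the current iterate at every iteration via updatexy. A minimal usage sketch is shown below; it is not part of this commit, and the L1 regularizer, step size tau, and iteration count niter are illustrative assumptions rather than values taken from the PR.

    import numpy as np
    from pyproximal import L1
    from pyproximal.optimization.primal import ProximalGradient
    from pyproximal.utils.bilinear import LowRankFactorizedMatrix

    n, m, k = 21, 11, 5
    U0 = np.random.normal(0., 1., (n, k))
    V0 = np.random.normal(0., 1., (m, k))
    d = np.random.normal(0., 1., (n, m))  # data to be factorized

    # smooth term H(x, y) = 1/2 ||Op(XY) - d||^2, used through its gradient
    lowrank = LowRankFactorizedMatrix(U0, V0.T, d)

    # sparsity-promoting regularizer on the stacked unknowns (illustrative choice)
    l1 = L1(sigma=1e-2)

    # stacked unknown [vec(X); vec(Y)]; updatexy is called internally each iteration
    x0 = np.hstack([U0.ravel(), V0.T.ravel()])
    xinv = ProximalGradient(lowrank, l1, x0=x0, tau=1e-3, niter=100)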

pyproximal/utils/bilinear.py

Lines changed: 23 additions & 8 deletions
@@ -16,6 +16,11 @@ class BilinearOperator():
     - ``lx``: Lipschitz constant of :math:`\nabla_x H`
     - ``ly``: Lipschitz constant of :math:`\nabla_y H`

+    Two additional methods (``updatex`` and ``updatey``) are provided to
+    update the :math:`\mathbf{x}` and :math:`\mathbf{y}` internal
+    variables. It is the user's responsibility to choose when to invoke
+    these methods (i.e., when to update the internal variables).
+
     Notes
     -----
     A bilinear operator is defined as a differentiable nonlinear function
@@ -45,15 +50,18 @@ def ly(self, y):
         pass

     def updatex(self, x):
-        """Update x variable (used when evaluating the gradient over y
+        """Update x variable (to be used to update the internal variable x)
         """
         self.x = x

     def updatey(self, y):
-        """Update y variable (used when evaluating the gradient over y
+        """Update y variable (to be used to update the internal variable y)
         """
         self.y = y

+    def updatexy(self, xy):
+        pass
+

 class LowRankFactorizedMatrix(BilinearOperator):
     r"""Low-Rank Factorized Matrix operator.
@@ -83,16 +91,21 @@ class LowRankFactorizedMatrix(BilinearOperator):

     .. math::

-        \nabla_x H = \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
+        \nabla_x H(\mathbf{x}; \mathbf{y}) =
+        \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
         - \mathbf{d})\mathbf{Y}^H

     and gradient with respect to y equal to:

     .. math::

-        \nabla_y H = \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
+        \nabla_y H(\mathbf{y}; \mathbf{x}) =
+        \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
         (\mathbf{X}\mathbf{Y}) - \mathbf{d})

+    Note that in both cases, the currently stored x/y is used for
+    the second variable within parentheses (after the ;).
+
     """
     def __init__(self, X, Y, d, Op=None):
         self.n, self.k = X.shape
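
As a quick sanity check of the gradient expression for ∇_x H added above, the following standalone NumPy snippet (not part of the commit) compares the analytical gradient with central finite differences in the simplest setting, i.e. real-valued factors and Op equal to the identity, so that H(X, Y) = 1/2 ||XY − D||_F^2 and ∇_X H = (XY − D) Y^T:

    import numpy as np

    rng = np.random.default_rng(0)
    n, m, k = 6, 5, 3
    X = rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m))
    D = rng.standard_normal((n, m))

    # objective and analytical gradient from the docstring formula
    H = lambda X: 0.5 * np.linalg.norm(X @ Y - D) ** 2
    grad_analytic = (X @ Y - D) @ Y.T  # Y^H reduces to Y^T for real Y

    # numerical gradient by central differences
    eps = 1e-6
    grad_numeric = np.zeros_like(X)
    for i in range(n):
        for j in range(k):
            E = np.zeros_like(X)
            E[i, j] = eps
            grad_numeric[i, j] = (H(X + E) - H(X - E)) / (2 * eps)

    assert np.allclose(grad_analytic, grad_numeric, atol=1e-5)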
@@ -160,7 +173,7 @@ def gradx(self, x):
             r = (self.Op.H @ r).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -r @ self.y.reshape(self.k, self.m).T
+        g = -r @ np.conj(self.y.reshape(self.k, self.m).T)
         return g.ravel()

     def grady(self, y):
@@ -169,13 +182,15 @@ def grady(self, y):
             r = (self.Op.H @ r.ravel()).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -self.x.reshape(self.n, self.k).T @ r
+        g = -np.conj(self.x.reshape(self.n, self.k).T) @ r
         return g.ravel()

     def grad(self, x):
-        self.updatex(x[:self.n * self.k])
-        self.updatey(x[self.n * self.k:])
         gx = self.gradx(x[:self.n * self.k])
         gy = self.grady(x[self.n * self.k:])
         g = np.hstack([gx, gy])
         return g
+
+    def updatexy(self, x):
+        self.updatex(x[:self.n * self.k])
+        self.updatey(x[self.n * self.k:])
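
The new updatexy method simply splits the stacked vector of both unknowns into its x part (size n*k) and y part (size k*m) and refreshes the internal variables, which is exactly what ProximalGradient now does at every iteration. A small sketch of its effect (not part of the commit):

    import numpy as np
    from pyproximal.utils.bilinear import LowRankFactorizedMatrix

    n, m, k = 4, 3, 2
    U = np.random.normal(0., 1., (n, k))
    V = np.random.normal(0., 1., (m, k))
    LOp = LowRankFactorizedMatrix(U, V.T, U @ V.T)

    # stacked iterate [vec(X); vec(Y)]; pretend the solver produced a new one
    xy = np.hstack([U.ravel(), V.T.ravel()])
    LOp.updatexy(2 * xy)

    # internal variables now hold the new iterate
    assert np.allclose(LOp.x, 2 * U.ravel())
    assert np.allclose(LOp.y, 2 * V.T.ravel())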

pytests/test_bilinear.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from pylops import Diagonal
+
+from pyproximal.utils.bilinear import LowRankFactorizedMatrix
+
+
+par1 = {'n': 21, 'm': 11, 'k': 5, 'imag': 0,
+        'dtype': 'float32'}  # real
+par2 = {'n': 21, 'm': 11, 'k': 5, 'imag': 1j,
+        'dtype': 'complex64'}  # complex
+
+np.random.seed(10)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_lrfactorized(par):
+    """Check equivalence of matvec operations in LowRankFactorizedMatrix
+    """
+    U = (np.random.normal(0., 1., (par['n'], par['k'])) + \
+         par['imag'] * np.random.normal(0., 1., (par['n'], par['k']))).astype(par['dtype'])
+    V = (np.random.normal(0., 1., (par['m'], par['k'])) + \
+         par['imag'] * np.random.normal(0., 1., (par['m'], par['k']))).astype(par['dtype'])
+
+    X = U @ V.T
+    LOp = LowRankFactorizedMatrix(U, V.T, X)
+
+    assert_array_equal(X.ravel(), LOp._matvecx(U.ravel()))
+    assert_array_equal(X.ravel(), LOp._matvecy(V.T.ravel()))
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_lrfactorizedoperator(par):
+    """Check equivalence of matvec operations in LowRankFactorizedMatrix with operator Op
+    """
+    U = (np.random.normal(0., 1., (par['n'], par['k'])) + \
+         par['imag'] * np.random.normal(0., 1., (par['n'], par['k']))).astype(par['dtype'])
+    V = (np.random.normal(0., 1., (par['m'], par['k'])) + \
+         par['imag'] * np.random.normal(0., 1., (par['m'], par['k']))).astype(par['dtype'])
+    Op = Diagonal(np.arange(par['n'] * par['m']) + 1.)
+
+    X = U @ V.T
+    y = Op @ X.ravel()
+
+    LOp = LowRankFactorizedMatrix(U, V.T, y, Op)
+
+    assert_array_equal(y, LOp._matvecx(U.ravel()))
+    assert_array_equal(y, LOp._matvecy(V.T.ravel()))
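
The two tests are parametrized over the real (par1) and complex (par2) cases; they can be run locally with a standard pytest invocation such as `python -m pytest pytests/test_bilinear.py`.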
