
Commit 14c2e66

Added grad to LowRankFactorizedMatrix
1 parent 2aa9fe5 commit 14c2e66

2 files changed: +17 -2 lines changed

pyproximal/optimization/segmentation.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def Segment(y, cl, sigma, alpha, clsigmas=None, z=None, niter=10, x0=None,
     :math:`\mathbf{\sigma}=[\sigma_1, ..., \sigma_{N_{cl}}]^T` are vectors
     representing the optimal mean and standard deviations for each class.
 
-    .. [1] Chambolle, and A., Pock, "A first-order primal-dual algorithm for
+    .. [1] Chambolle, and A., Pock, "A first-order primal-dual algorithm for
       convex problems with applications to imaging", Journal of Mathematical
       Imaging and Vision, 40, 8pp. 120–145. 2011.

pyproximal/utils/bilinear.py

Lines changed: 16 additions & 1 deletion

@@ -11,6 +11,8 @@ class BilinearOperator():
       :math:`\nabla_x H`
     - ``grady``: a method evaluating the gradient over :math:`\mathbf{y}`:
       :math:`\nabla_y H`
+    - ``grad``: a method returning the stacked gradient vector over
+      :math:`\mathbf{x},\mathbf{y}`: :math:`[\nabla_x H, \nabla_y H]`
     - ``lx``: Lipschitz constant of :math:`\nabla_x H`
     - ``ly``: Lipschitz constant of :math:`\nabla_y H`
@@ -33,6 +35,9 @@ def gradx(self, x):
     def grady(self, y):
         pass
 
+    def grad(self, y):
+        pass
+
     def lx(self, x):
         pass

@@ -100,7 +105,9 @@ def __init__(self, X, Y, d, Op=None):
         self.shapex = (self.n * self.m, self.n * self.k)
         self.shapey = (self.n * self.m, self.m * self.k)
 
-    def __call__(self, x, y):
+    def __call__(self, x, y=None):
+        if y is None:
+            x, y = x[:self.n * self.k], x[self.n * self.k:]
         xold = self.x.copy()
         self.updatex(x)
         res = self.d - self._matvecy(y)
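With ``y`` now optional, ``__call__`` can be evaluated on a single vector stacking both unknowns, split at ``self.n * self.k``. A minimal sketch of that splitting convention (the sizes and array names below are illustrative, not part of the commit):

import numpy as np

n, m, k = 4, 3, 2                      # x corresponds to an (n, k) factor, y to (k, m)
xy = np.random.randn(n * k + k * m)    # stacked vector [x, y]
x, y = xy[:n * k], xy[n * k:]          # same split applied when y is None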
@@ -159,3 +166,11 @@ def grady(self, y):
         r = r.reshape(self.n, self.m)
         g = -self.x.reshape(self.n, self.k).T @ r
         return g.ravel()
+
+    def grad(self, x):
+        self.updatex(x[:self.n * self.k])
+        self.updatey(x[self.n * self.k:])
+        gx = self.gradx(x[:self.n * self.k])
+        gy = self.grady(x[self.n * self.k:])
+        g = np.hstack([gx, gy])
+        return g
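The new ``grad`` updates both factors from the stacked vector and returns the concatenated ``gradx``/``grady`` gradients. A minimal usage sketch, assuming ``X`` and ``Y`` are passed to the constructor as ``(n, k)`` and ``(k, m)`` arrays and ``d`` as the flattened data matrix (conventions inferred from the shapes used above, not stated in this diff):

import numpy as np
from pyproximal.utils.bilinear import LowRankFactorizedMatrix

n, m, k = 6, 5, 2
X = np.random.randn(n, k)               # left factor (assumed shape)
Y = np.random.randn(k, m)               # right factor (assumed shape)
d = (X @ Y).ravel()                     # observed data, assumed flattened

H = LowRankFactorizedMatrix(X, Y, d)
xy = np.hstack([X.ravel(), Y.ravel()])  # single stacked unknown [x, y]

g = H.grad(xy)                          # equals np.hstack([H.gradx(x), H.grady(y)])
print(g.shape)                          # (n * k + k * m,)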
