@@ -11,6 +11,8 @@ class BilinearOperator():
1111 :math:`\nabla_x H`
1212 - ``grady``: a method evaluating the gradient over :math:`\mathbf{y}`:
1313 :math:`\nabla_y H`
- ``grad``: a method returning the stacked gradient vector over
  :math:`\mathbf{x},\mathbf{y}`: :math:`[\nabla_x H, \nabla_y H]`
1416 - ``lx``: Lipschitz constant of :math:`\nabla_x H`
1517 - ``ly``: Lipschitz constant of :math:`\nabla_y H`
1618
@@ -33,6 +35,9 @@ def gradx(self, x):
def grady(self, y):
    r"""Evaluate the gradient of H over ``y``, :math:`\nabla_y H`.

    Abstract placeholder — concrete bilinear operators are expected to
    override this with an actual gradient computation.
    """
    pass
3537
def grad(self, x):
    r"""Evaluate the stacked gradient over :math:`\mathbf{x},\mathbf{y}`.

    Abstract placeholder — concrete bilinear operators override this to
    return :math:`[\nabla_x H, \nabla_y H]` as a single stacked vector.

    Parameters
    ----------
    x : stacked vector holding both the ``x`` and ``y`` components,
        as in the concrete implementation (``x[:n*k]`` / ``x[n*k:]``).

    Notes
    -----
    The parameter is named ``x`` (not ``y``) for consistency with the
    class docstring and with the concrete ``grad`` implementation, both
    of which treat the argument as the stacked :math:`[x, y]` vector.
    """
    pass
40+
def lx(self, x):
    r"""Return the Lipschitz constant of :math:`\nabla_x H`.

    Abstract placeholder — concrete bilinear operators are expected to
    override this with an actual constant (possibly depending on ``x``).
    """
    pass
3843
@@ -100,7 +105,9 @@ def __init__(self, X, Y, d, Op=None):
100105 self .shapex = (self .n * self .m , self .n * self .k )
101106 self .shapey = (self .n * self .m , self .m * self .k )
102107
103- def __call__ (self , x , y ):
108+ def __call__ (self , x , y = None ):
109+ if y is None :
110+ x , y = x [:self .n * self .k ], x [self .n * self .k :]
104111 xold = self .x .copy ()
105112 self .updatex (x )
106113 res = self .d - self ._matvecy (y )
@@ -159,3 +166,11 @@ def grady(self, y):
159166 r = r .reshape (self .n , self .m )
160167 g = - self .x .reshape (self .n , self .k ).T @ r
161168 return g .ravel ()
169+
def grad(self, x):
    r"""Evaluate the stacked gradient :math:`[\nabla_x H, \nabla_y H]`.

    Parameters
    ----------
    x : stacked vector whose first ``n * k`` entries are the ``x``
        component and whose remaining entries are the ``y`` component.

    Returns
    -------
    1-D array with the gradients over ``x`` and ``y`` concatenated.
    """
    # Split the stacked input at the boundary between the two components.
    split = self.n * self.k
    xcomp = x[:split]
    ycomp = x[split:]
    # Refresh the operator's internal state before differentiating.
    self.updatex(xcomp)
    self.updatey(ycomp)
    return np.hstack([self.gradx(xcomp), self.grady(ycomp)])
0 commit comments