@@ -16,6 +16,11 @@ class BilinearOperator():
     - ``lx``: Lipschitz constant of :math:`\nabla_x H`
     - ``ly``: Lipschitz constant of :math:`\nabla_y H`
 
+    Two additional methods (``updatex`` and ``updatey``) are provided to
+    update the :math:`\mathbf{x}` and :math:`\mathbf{y}` internal
+    variables. It is the user's responsibility to choose when to invoke
+    these methods (i.e., when to update the internal variables).
+
     Notes
     -----
     A bilinear operator is defined as a differentiable nonlinear function
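
A minimal sketch of how the ``updatex``/``updatey`` hooks introduced above might be used in an alternating (block-coordinate) scheme, where the gradient over one variable is taken while the other is held at its stored value. ``LowRankFactorizedMatrix``, ``gradx``, ``grady``, ``updatex`` and ``updatey`` are the ones defined in this file; the sizes, the step sizes ``tau_x``/``tau_y`` and the flattened data ``d`` are illustrative assumptions, not part of the patch:

    import numpy as np

    # illustrative sizes and data (assumption: d is the flattened n x m observation)
    n, m, k = 30, 20, 5
    rng = np.random.default_rng(0)
    X = rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m))
    d = rng.standard_normal(n * m)

    op = LowRankFactorizedMatrix(X, Y, d)   # Op=None, i.e. no extra modelling operator
    x, y = X.ravel(), Y.ravel()
    tau_x, tau_y = 1e-3, 1e-3               # illustrative step sizes

    for _ in range(10):
        x = x - tau_x * op.gradx(x)         # gradient over x with the stored y fixed
        op.updatex(x)                       # user decides when to refresh the internal x
        y = y - tau_y * op.grady(y)         # gradient over y with the (new) stored x fixed
        op.updatey(y)                       # user decides when to refresh the internal y
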
@@ -45,15 +50,17 @@ def ly(self, y):
         pass
 
     def updatex(self, x):
48- """Update x variable (used when evaluating the gradient over y
53+ """Update x variable (to be used to update the internal variable x)
4954 """
         self.x = x
 
     def updatey(self, y):
53- """Update y variable (used when evaluating the gradient over y
58+ """Update y variable (to be used to update the internal variable y)
5459 """
         self.y = y
 
+    def updatexy(self, xy):
+        pass  # no-op here; subclasses override it to update x and y jointly
 
 class LowRankFactorizedMatrix(BilinearOperator):
     r"""Low-Rank Factorized Matrix operator.
@@ -83,16 +90,21 @@ class LowRankFactorizedMatrix(BilinearOperator):
 
     .. math::
 
-        \nabla_x H = \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
+        \nabla_x H(\mathbf{x}; \mathbf{y}) =
+        \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
         - \mathbf{d})\mathbf{Y}^H
 
     and gradient with respect to y equal to:
 
     .. math::
 
-        \nabla_y H = \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
+        \nabla_y H(\mathbf{y}; \mathbf{x}) =
+        \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
         (\mathbf{X}\mathbf{Y}) - \mathbf{d})
 
+    Note that in both cases the currently stored x/y is used for the
+    second variable in parentheses (the one after the semicolon).
+
     """
     def __init__(self, X, Y, d, Op=None):
         self.n, self.k = X.shape
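
The gradient formulas in the docstring above can be checked in isolation with plain NumPy for the ``Op=None`` (identity) case, assuming the functional is H(x, y) = 0.5 * ||X Y - d||_F^2, which is consistent with the gradients above. This is an illustrative verification, not part of the patch, and all names (``H``, ``D``, ``E``, ``eps``) are local to the sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    n, m, k = 6, 5, 3
    X = rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m))
    D = rng.standard_normal((n, m))

    def H(Xv, Yv):
        # H(x, y) = 1/2 ||X Y - D||_F^2 (Op = identity)
        return 0.5 * np.linalg.norm(Xv @ Yv - D) ** 2

    # gradient over x from the formula in the docstring: (X Y - D) Y^H
    grad_x = (X @ Y - D) @ Y.conj().T

    # central finite-difference check of one entry of grad_x
    eps = 1e-6
    E = np.zeros_like(X)
    E[1, 2] = eps
    fd = (H(X + E, Y) - H(X - E, Y)) / (2 * eps)
    print(np.isclose(fd, grad_x[1, 2]))     # True up to finite-difference error
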
@@ -160,7 +172,7 @@ def gradx(self, x):
             r = (self.Op.H @ r).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -r @ self.y.reshape(self.k, self.m).T
+        g = -r @ np.conj(self.y.reshape(self.k, self.m).T)
         return g.ravel()
 
     def grady(self, y):
@@ -169,13 +181,15 @@ def grady(self, y):
             r = (self.Op.H @ r.ravel()).reshape(self.n, self.m)
         else:
             r = r.reshape(self.n, self.m)
-        g = -self.x.reshape(self.n, self.k).T @ r
+        g = -np.conj(self.x.reshape(self.n, self.k).T) @ r
         return g.ravel()
 
     def grad(self, x):
-        self.updatex(x[:self.n * self.k])
-        self.updatey(x[self.n * self.k:])
         gx = self.gradx(x[:self.n * self.k])
         gy = self.grady(x[self.n * self.k:])
         g = np.hstack([gx, gy])
         return g
+
+    def updatexy(self, x):
+        self.updatex(x[:self.n * self.k])  # first n*k entries -> x
+        self.updatey(x[self.n * self.k:])  # remaining entries -> y
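
With this change ``grad`` no longer refreshes the internal variables itself, so callers that relied on that side effect now have to call ``updatexy`` explicitly. A minimal sketch of the resulting workflow for a plain gradient descent over the stacked vector, reusing ``op``, ``X`` and ``Y`` from the sketch above and an illustrative step size ``tau``:

    import numpy as np

    xy = np.hstack([X.ravel(), Y.ravel()])  # stacked unknown [x; y]
    tau = 1e-3                              # illustrative step size

    for _ in range(100):
        g = op.grad(xy)       # the stored x/y enter as the fixed second variable of each gradient
        xy = xy - tau * g
        op.updatexy(xy)       # explicitly refresh the stored x and y for the next iteration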