@@ -16,6 +16,11 @@ class BilinearOperator():
1616 - ``lx``: Lipschitz constant of :math:`\nabla_x H`
1717 - ``ly``: Lipschitz constant of :math:`\nabla_y H`
1818
19+ Two additional methods (``updatex`` and ``updatey``) are provided to
20+ update the :math:`\mathbf{x}` and :math:`\mathbf{y}` internal
21+ variables. It is the user's responsibility to choose when to invoke
22+ such methods (i.e., when to update the internal variables).
23+
1924 Notes
2025 -----
2126 A bilinear operator is defined as a differentiable nonlinear function
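A minimal sketch of the alternating-update pattern described in the docstring
above, using the ``LowRankFactorizedMatrix`` subclass defined later in this
file; the flattened layout of ``d``, the starting guesses, and the step size
are illustrative assumptions rather than requirements of the class::

    import numpy as np

    # Illustrative sizes: X is (n, k), Y is (k, m); d is assumed flattened
    n, m, k = 8, 6, 2
    rng = np.random.default_rng(0)
    X = rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m))
    d = (X @ Y).ravel() + 0.01 * rng.standard_normal(n * m)

    hop = LowRankFactorizedMatrix(X, Y, d)  # Op=None: plain factorization misfit

    x, y, step = X.ravel(), Y.ravel(), 0.1  # step size chosen arbitrarily
    for _ in range(10):
        # gradient step in x with y held at its currently stored value
        x = x - step * hop.gradx(x)
        hop.updatex(x)   # the user decides when to refresh the stored x
        # gradient step in y with x held at its currently stored value
        y = y - step * hop.grady(y)
        hop.updatey(y)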
@@ -45,15 +50,18 @@ def ly(self, y):
4550 pass
4651
4752 def updatex(self, x):
48- """Update x variable (used when evaluating the gradient over y
53+ """Update x variable (to be used to update the internal variable x)
4954 """
5055 self.x = x
5156
5257 def updatey(self, y):
53- """Update y variable (used when evaluating the gradient over y
58+ """Update y variable (to be used to update the internal variable y)
5459 """
5560 self.y = y
5661
62+ def updatexy(self, xy):
63+ pass
64+
5765
5866 class LowRankFactorizedMatrix(BilinearOperator):
5967 r"""Low-Rank Factorized Matrix operator.
@@ -83,16 +91,21 @@ class LowRankFactorizedMatrix(BilinearOperator):
8391
8492 .. math::
8593
86- \nabla_x H = \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
94+ \nabla_x H(\mathbf{x}; \mathbf{y}) =
95+ \mathbf{Op}^H(\mathbf{Op}(\mathbf{X}\mathbf{Y})
8796 - \mathbf{d})\mathbf{Y}^H
8897
8998 and gradient with respect to y equal to:
9099
91100 .. math::
92101
93- \nabla_y H = \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
102+ \nabla_y H(\mathbf{y}; \mathbf{x}) =
103+ \mathbf{X}^H \mathbf{Op}^H(\mathbf{Op}
94104 (\mathbf{X}\mathbf{Y}) - \mathbf{d})
95105
106+ Note that in both cases the currently stored x/y is used for the
107+ second variable within parentheses (the one after the semicolon).
108+
96109 """
97110 def __init__(self, X, Y, d, Op=None):
98111 self.n, self.k = X.shape
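As a quick check of the two formulas above (and of why the Hermitian
transpose, i.e. ``np.conj(...).T``, matters once the factors are complex
valued), the gradients can be written out directly for the ``Op=None`` case;
this is a standalone sketch, not code taken from the class::

    import numpy as np

    # Complex-valued toy factors and data, with Op taken as the identity
    n, m, k = 5, 4, 2
    rng = np.random.default_rng(0)
    X = rng.standard_normal((n, k)) + 1j * rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m)) + 1j * rng.standard_normal((k, m))
    d = rng.standard_normal((n, m)) + 1j * rng.standard_normal((n, m))

    R = X @ Y - d                 # Op(XY) - d with Op = identity
    grad_x = R @ np.conj(Y).T     # (Op^H(Op(XY) - d)) Y^H
    grad_y = np.conj(X).T @ R     # X^H (Op^H(Op(XY) - d))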
@@ -160,7 +173,7 @@ def gradx(self, x):
160173 r = (self.Op.H @ r).reshape(self.n, self.m)
161174 else:
162175 r = r.reshape(self.n, self.m)
163- g = -r @ self.y.reshape(self.k, self.m).T
176+ g = -r @ np.conj(self.y.reshape(self.k, self.m).T)
164177 return g.ravel()
165178
166179 def grady(self, y):
@@ -169,13 +182,15 @@ def grady(self, y):
169182 r = (self.Op.H @ r.ravel()).reshape(self.n, self.m)
170183 else:
171184 r = r.reshape(self.n, self.m)
172- g = -self.x.reshape(self.n, self.k).T @ r
185+ g = -np.conj(self.x.reshape(self.n, self.k).T) @ r
173186 return g.ravel()
174187
175188 def grad(self, x):
176- self.updatex(x[:self.n * self.k])
177- self.updatey(x[self.n * self.k:])
178189 gx = self.gradx(x[:self.n * self.k])
179190 gy = self.grady(x[self.n * self.k:])
180191 g = np.hstack([gx, gy])
181192 return g
193+
194+ def updatexy(self, x):
195+ self.updatex(x[:self.n * self.k])  # first n*k entries update internal x
196+ self.updatey(x[self.n * self.k:])  # remaining entries update internal y
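With this change ``grad`` no longer refreshes the internal variables itself,
so a caller working on the stacked vector is expected to invoke ``updatexy``
explicitly before evaluating ``grad``. A sketch of the resulting calling
pattern, with the same illustrative assumptions on data layout and step size
as above::

    import numpy as np

    n, m, k = 8, 6, 2
    rng = np.random.default_rng(0)
    X = rng.standard_normal((n, k))
    Y = rng.standard_normal((k, m))
    d = (X @ Y).ravel()  # assumed flattened data layout

    hop = LowRankFactorizedMatrix(X, Y, d)

    xy = np.hstack([X.ravel(), Y.ravel()])  # stacked [x, y] variable
    for _ in range(10):
        hop.updatexy(xy)              # store the current point first...
        xy = xy - 0.1 * hop.grad(xy)  # ...then evaluate the full gradient there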