|
41 | 41 | class BinaryOperator(AffAtom): |
42 | 42 | """ |
43 | 43 | Base class for expressions involving binary operators. (other than addition) |
44 | | -
|
45 | 44 | """ |
46 | 45 |
|
47 | 46 | OP_NAME = 'BINARY_OP' |
@@ -172,37 +171,40 @@ def is_decr(self, idx) -> bool: |
172 | 171 | return self.args[1-idx].is_nonpos() |
173 | 172 |
|
174 | 173 | def _grad(self, values): |
175 | | - """Gives the (sub/super)gradient of the atom w.r.t. each argument. |
176 | | -
|
177 | | - Matrix expressions are vectorized, so the gradient is a matrix. |
178 | | -
|
| 174 | + """Compute the gradient of matrix multiplication w.r.t. each argument. |
| 175 | + |
| 176 | + For Z = X @ Y, this computes the Jacobian matrices: |
| 177 | + - ∂vec(Z)/∂vec(X) = Y.T ⊗ I_m where X is (m, n) |
| 178 | + - ∂vec(Z)/∂vec(Y) = I_p ⊗ X where Y is (n, p) |
| 179 | + |
179 | 180 | Args: |
180 | | - values: A list of numeric values for the arguments. |
181 | | -
|
| 181 | + values: A list of numeric values for the arguments [X, Y]. |
| 182 | + |
182 | 183 | Returns: |
183 | | - A list of SciPy CSC sparse matrices or None. |
| 184 | + A list of SciPy CSC sparse matrices [DX, DY]. |
184 | 185 | """ |
| 186 | + # Handle constant cases |
185 | 187 | if self.args[0].is_constant() or self.args[1].is_constant(): |
186 | 188 | return super(MulExpression, self)._grad(values) |
187 | | - |
188 | | - # TODO(akshayka): Verify that the following code is correct for |
189 | | - # non-affine arguments. |
| 189 | + |
190 | 190 | X = values[0] |
191 | 191 | Y = values[1] |
192 | | - |
193 | | - DX_rows = self.args[0].size |
194 | | - cols = self.args[0].size |
195 | | - |
196 | | - # DX = [diag(Y11), diag(Y12), ...] |
197 | | - # [diag(Y21), diag(Y22), ...] |
198 | | - # [ ... ... ...] |
199 | | - DX = sp.dok_array((DX_rows, cols)) |
200 | | - for k in range(self.args[0].shape[0]): |
201 | | - DX[k::self.args[0].shape[0], k::self.args[0].shape[0]] = Y |
202 | | - DX = sp.csc_array(DX) |
203 | | - cols = 1 if len(self.args[1].shape) == 1 else self.args[1].shape[1] |
204 | | - DY = sp.block_diag([X.T for k in range(cols)], 'csc') |
205 | | - |
| 192 | + |
| 193 | + # Get dimensions |
| 194 | + m, n = self.args[0].shape if len(self.args[0].shape) == 2 else (self.args[0].size, 1) |
| 195 | + n2, p = self.args[1].shape if len(self.args[1].shape) == 2 else (self.args[1].size, 1) |
| 196 | + |
| 197 | + # Verify dimension compatibility |
| 198 | + assert n == n2, f"Inner dimensions must match for multiplication: {n} != {n2}" |
| 199 | + |
| 200 | + # Compute ∂vec(Z)/∂vec(X) = Y.T ⊗ I_m |
| 201 | + # This is a (m*p) × (m*n) matrix |
| 202 | + DX = sp.kron(Y.T, sp.eye(m, format='csc'), format='csc') |
| 203 | + |
| 204 | + # Compute ∂vec(Z)/∂vec(Y) = I_p ⊗ X |
| 205 | + # This is a (m*p) × (n*p) matrix |
| 206 | + DY = sp.kron(sp.eye(p, format='csc'), X, format='csc') |
| 207 | + |
206 | 208 | return [DX, DY] |
207 | 209 |
|
208 | 210 | def graph_implementation( |
|
0 commit comments