Implement gradient for QR decomposition #1303
@@ -512,6 +512,78 @@
        else:
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        Adapted from [1]_, which is also the basis of the forward-mode implementation in JAX:
        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
        and from [2]_, which describes a solution for the square-matrix case.

        References
        ----------
        .. [1] Townsend, James. "Differentiating the qr decomposition." Online draft,
           https://j-towns.github.io/papers/qr-derivative.pdf (2018).
        .. [2] Sebastian F. Walter, Lutz Lehmann & René Lamour. "On evaluating higher-order
           derivatives of the QR decomposition of tall matrices with full column rank in
           forward and reverse mode algorithmic differentiation", Optimization Methods and
           Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454.
        """

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        *_, m, n = A.type.shape

        def _H(x: ptb.TensorVariable):
            return x.conj().T

        def _copyutl(x: ptb.TensorVariable):
            # Keep the upper triangle (with diagonal) of x and mirror its strict part,
            # conjugate transposed, into the lower triangle.
            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))

        if self.mode == "raw" or (self.mode == "complete" and m != n):
            raise NotImplementedError("Gradient of qr not implemented")

        elif m < n:
            raise NotImplementedError(
                "Gradient of qr not implemented for m x n matrices with m < n"
            )

        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
            dR = cast(ptb.TensorVariable, output_grads[0])
            R_dRt = R @ _H(dR)
            Rinvt = _H(inv(R))
            A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)

            return [A_bar]

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

            Rinvt = _H(inv(R))

            Qt_dQ = _H(Q) @ dQ
            R_dRt = R @ _H(dR)
            A_bar = (
                Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
            ) @ Rinvt + Q @ dR

            return [A_bar]

def qr(a, mode="reduced"):
    """
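The hunk above truncates at the start of the qr docstring. For reference, the reverse-mode identity from Townsend [1] that both gradient branches adapt can be sketched as follows (real, square case; \bar{Q} and \bar{R} stand for the incoming cotangents dQ and dR, and copyltu, which is the paper's notation rather than this diff's, copies the lower triangle of its argument, diagonal included, into the upper triangle):

    M = R \bar{R}^{\top} - \bar{Q}^{\top} Q
    \operatorname{copyltu}(M) = \operatorname{tril}(M) + \operatorname{tril}(M, -1)^{\top}
    \bar{A} = \left( \bar{Q} + Q \, \operatorname{copyltu}(M) \right) R^{-\top}

The mode="r" branch corresponds to the special case \bar{Q} = 0, in which only the cotangent of R is propagated back to A.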
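Below is a hedged usage sketch, not part of this PR, of how the new gradient could be exercised once merged. It assumes qr is importable from pytensor.tensor.nlinalg and that giving the input a static shape is enough for L_op to pick its branch; the weights, cost, and finite-difference check are illustrative only.

    # Hedged sketch (not part of this PR); names and tolerances are illustrative.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import qr

    rng = np.random.default_rng(0)
    m, n = 4, 3                                   # m >= n, full column rank
    A_val = rng.standard_normal((m, n))
    Wq = rng.standard_normal((m, n))              # fixed weights so the cost depends on both Q and R
    Wr = rng.standard_normal((n, n))

    # Static shape so the Op can compare m and n when choosing a gradient branch
    A = pt.tensor("A", shape=(m, n), dtype="float64")
    Q, R = qr(A, mode="reduced")
    cost = pt.sum(Q * Wq) + pt.sum(R * Wr)        # arbitrary scalar cost touching both outputs
    g = pytensor.grad(cost, A)                    # calls the L_op added in this diff

    f_cost = pytensor.function([A], cost)
    f_grad = pytensor.function([A], g)

    # Crude central-difference check of a single gradient entry
    eps = 1e-6
    E = np.zeros_like(A_val)
    E[0, 0] = eps
    fd = (f_cost(A_val + E) - f_cost(A_val - E)) / (2 * eps)
    print(f_grad(A_val)[0, 0], fd)                # the two numbers should agree closely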