Commit a187464

Author: Etienne Duchesne (committed)
Implement gradient for QR decomposition
1 parent a149f6c commit a187464

2 files changed: +84 −0 lines changed


pytensor/tensor/nlinalg.py

Lines changed: 72 additions & 0 deletions
@@ -512,6 +512,78 @@ def perform(self, node, inputs, outputs):
         else:
             outputs[0][0] = res
 
+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Reverse-mode gradient of the QR function. Adapted from [1]_, which is also used in the forward-mode implementation in JAX here:
+        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
+
+        And from [2]_, which describes a solution for the square-matrix case.
+
+        References
+        ----------
+        .. [1] Townsend, James. "Differentiating the QR decomposition." Online draft, https://j-towns.github.io/papers/qr-derivative.pdf (2018).
+        .. [2] Sebastian F. Walter, Lutz Lehmann & René Lamour. "On evaluating higher-order derivatives
+           of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation",
+           Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454.
+        """
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        *_, m, n = A.type.shape
+
+        def _H(x: ptb.TensorVariable):
+            return x.conj().T
+
+        def _copyutl(x: ptb.TensorVariable):
+            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
+
+        if self.mode == "raw" or (self.mode == "complete" and m != n):
+            raise NotImplementedError("Gradient of qr not implemented")
+
+        elif m < n:
+            raise NotImplementedError(
+                "Gradient of qr not implemented for m x n matrices with m < n"
+            )
+
+        elif self.mode == "r":
+            # We need all the components of the QR to compute the gradient of A even if we only
+            # use the upper triangular component in the cost function.
+            Q, R = qr(A, mode="reduced")
+            dR = cast(ptb.TensorVariable, output_grads[0])
+            R_dRt = R @ _H(dR)
+            Rinvt = _H(inv(R))
+            A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
+            return [A_bar]
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never be reached by Pytensor
+                return [DisconnectedType()()]  # pragma: no cover
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [Q, R], strict=True
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            Rinvt = _H(inv(R))
+            Qt_dQ = _H(Q) @ dQ
+            R_dRt = R @ _H(dR)
+            A_bar = (
+                Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
+            ) @ Rinvt + Q @ dR
+
+            return [A_bar]
+
 
 def qr(a, mode="reduced"):
     """

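For reference (not part of the commit), the reverse-mode identity from [1] that both m >= n branches rearrange can be sketched for real-valued A = Q R in reduced mode, writing the incoming cotangents as \bar{Q} = dQ and \bar{R} = dR:

    M = R \bar{R}^{\top} - \bar{Q}^{\top} Q
    \bar{A} = \bigl( \bar{Q} + Q \, \operatorname{copyltu}(M) \bigr) R^{-\top}

Here copyltu(M) keeps the lower triangle of M (diagonal included) and mirrors it into the upper triangle; in the diff, _H is the conjugate transpose and _copyutl is the analogous helper that mirrors an upper triangle into the lower one. The ptb.tril(R_dRt - _H(R_dRt), k=-1) and _copyutl(Qt_dQ) terms are an algebraically equivalent regrouping of copyltu(M) with the Q @ dR contribution pulled outside the bracket; setting \bar{Q} = 0 recovers the mode="r" branch, and [2] derives the corresponding result for the square case.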
tests/tensor/test_nlinalg.py

Lines changed: 12 additions & 0 deletions
@@ -152,6 +152,18 @@ def test_qr_modes():
         assert "name 'complete' is not defined" in str(e)
 
 
+@pytest.mark.parametrize("shape", [(3, 3), (6, 3)], ids=["shape=(3, 3)", "shape=(6,3)"])
+@pytest.mark.parametrize("output", [0, 1], ids=["Q", "R"])
+def test_qr_grad(shape, output):
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    def _test_fn(x):
+        return qr(x, mode="reduced")[output]
+
+    a = rng.standard_normal(shape).astype(config.floatX)
+    utt.verify_grad(_test_fn, [a], rng=np.random)
+
+
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
 
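Not part of the commit, but a minimal usage sketch of the gradient this adds; the cost function, shapes, and seed are arbitrary illustrations, and the import paths assume the current PyTensor layout:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import qr

# Build a scalar cost from both reduced QR factors of a tall matrix,
# then differentiate it with respect to the input matrix. The cotangents
# of Q and R are routed through the new L_op above.
A = pt.matrix("A")
Q, R = qr(A, mode="reduced")
cost = pt.sqr(Q).sum() + pt.sqr(R).sum()
dA = pytensor.grad(cost, A)

f = pytensor.function([A], dA)
x = np.random.default_rng(0).standard_normal((6, 3)).astype(pytensor.config.floatX)
print(f(x))

Because the mode="r" branch rebuilds the full reduced factorization internally, a graph that only uses R (qr(A, mode="r")) is differentiable as well, at the cost of recomputing Q inside the gradient.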