Commit 9e5e765
Author: Etienne Duchesne

qr decomposition gradient: extend gradient to more input shapes
* for mode=reduced or mode=r, all input shapes are accepted
* for mode=complete, shapes m x n with m <= n are accepted
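
As a rough illustration of what this enables (a sketch, not code from this commit; it assumes only pytensor's public `qr`, `grad`, and `function` API), a QR-based cost on a wide matrix can now be differentiated in mode=reduced:

```python
# Illustrative sketch (not part of this commit): differentiate a cost
# built from the QR of a wide (m < n) matrix, previously unsupported.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import qr

A = pt.matrix("A")
Q, R = qr(A, mode="reduced")
cost = pt.square(Q).sum() + pt.square(R).sum()
A_bar = pytensor.grad(cost, A)  # goes through the extended L_op

f = pytensor.function([A], A_bar)
x = np.random.default_rng(0).normal(size=(3, 5))  # m = 3 < n = 5
print(f(x))
```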
Parent: 0e47b7d

2 files changed: +32 additions, -42 deletions

pytensor/tensor/nlinalg.py (30 additions, 41 deletions)

```diff
@@ -5,6 +5,7 @@
 
 import numpy as np
 
+import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
@@ -515,64 +516,43 @@ def perform(self, node, inputs, outputs):
 
     def L_op(self, inputs, outputs, output_grads):
         """
-        Reverse-mode gradient of the QR function. Adapted from ..[1], which is used in the forward-mode implementation in jax here:
-        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
-
-        And from ..[2] which describes a solution in the square matrix case.
+        Reverse-mode gradient of the QR function.
 
         References
         ----------
-        .. [1] Townsend, James. "Differentiating the qr decomposition." online draft https://j-towns.github.io/papers/qr-derivative.pdf (2018)
-        .. [2] Sebastian F. Walter , Lutz Lehmann & René Lamour. "On evaluating higher-order derivatives
-        of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation",
-        Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
+        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
+        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
         """
 
         from pytensor.tensor.slinalg import solve_triangular
 
         (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        m, n = A.shape
 
         def _H(x: ptb.TensorVariable):
             return x.conj().mT
 
-        def _copyutl(x: ptb.TensorVariable):
-            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
+        def _copyltu(x: ptb.TensorVariable):
+            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
 
         if self.mode == "raw":
             raise NotImplementedError("Gradient of qr not implemented for mode=raw")
 
-        elif self.mode == "complete":
-            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
         elif self.mode == "r":
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=r"
-            )
-            A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1]))
             # We need all the components of the QR to compute the gradient of A even if we only
             # use the upper triangular component in the cost function.
             Q, R = qr(A, mode="reduced")
+            dQ = Q.zeros_like()
+            dR = cast(ptb.TensorVariable, output_grads[0])
 
         else:
             Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
-        if self.mode == "r":
-            dR = cast(ptb.TensorVariable, output_grads[0])
-            R_dRt = R @ _H(dR)
-            M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = Q @ (M_Rinvt + dR)
-            return [A_bar]
+            if self.mode == "complete":
+                qr_assert_op = Assert(
+                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
+                )
+                R = qr_assert_op(R, ptm.le(m, n))
 
-        else:
             new_output_grads = []
             is_disconnected = [
                 isinstance(x.type, DisconnectedType) for x in output_grads
@@ -591,13 +571,22 @@ def _copyutl(x: ptb.TensorVariable):
 
             (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
 
-            Qt_dQ = _H(Q) @ dQ
-            R_dRt = R @ _H(dR)
-            M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = M_Rinvt + Q @ dR
-
-            return [A_bar]
+        # gradient expression when m >= n
+        M = R @ _H(dR) - _H(dQ) @ Q
+        K = dQ + Q @ _copyltu(M)
+        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
+
+        # gradient expression when m < n
+        Y = A[:, m:]
+        U = R[:, :m]
+        dU, dV = dR[:, :m], dR[:, m:]
+        dQ_Yt_dV = dQ + Y @ _H(dV)
+        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
+        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
+        Y_bar = Q @ dV
+        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
+
+        return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
 
 
 def qr(a, mode="reduced"):
```
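
For context, the m >= n branch above is the QR adjoint from the two references now cited in the docstring ([1], [2]); restated as a sketch of the math, with bars for adjoints, daggers for conjugate transposes, and copyltu as implemented by `_copyltu`:

```latex
% QR adjoint for A = QR with m >= n, matching the code above:
%   copyltu(M) = tril(M, 0) + tril(M, -1)^{\dagger}
M = R\,\bar{R}^{\dagger} - \bar{Q}^{\dagger} Q, \qquad
\bar{A} = \bigl(\bar{Q} + Q\,\mathrm{copyltu}(M)\bigr)\,R^{-\dagger}
```

The m < n branch splits A = [X, Y] with X the leading m x m block (so R = [U, V] and X = QU, Y = QV), applies the same formula to X with dQ replaced by dQ + Y @ _H(dV), and sets Y_bar = Q @ dV, which is what the concatenation in the code expresses.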

tests/tensor/test_nlinalg.py (2 additions, 1 deletion)

```diff
@@ -197,7 +197,8 @@ def _test_fn(x, case=2, mode="reduced"):
                 [a],
                 rng=np.random,
             )
-        elif m < n or (mode == "complete" and m != n):
+
+        elif mode == "complete" and m > n:
             with pytest.raises(AssertionError):
                 utt.verify_grad(
                     partial(_test_fn, case=gradient_test_case, mode=mode),
```
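
Condensed into a standalone check (hypothetical; the real test goes through `utt.verify_grad`), the behavior the updated branch encodes is that only mode=complete with m > n still raises:

```python
# Hypothetical standalone version of what the updated test asserts:
# only mode="complete" with m > n still fails at runtime.
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import qr

def grad_of_qr_cost(m, n, mode):
    A = pt.matrix("A")
    Q, R = qr(A, mode=mode)  # "reduced" and "complete" return (Q, R)
    f = pytensor.function([A], pytensor.grad(pt.square(R).sum(), A))
    return f(np.random.default_rng(0).normal(size=(m, n)))

grad_of_qr_cost(3, 5, "reduced")       # wide input: now supported
grad_of_qr_cost(3, 5, "complete")      # complete with m <= n: supported
with pytest.raises(AssertionError):
    grad_of_qr_cost(5, 3, "complete")  # complete with m > n: asserts
```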
