Skip to content

Commit ee9aaa2

Browse files
author
Etienne Duchesne
committed
replace inv by solve_triangular and improve test coverage in qr gradient
1 parent a187464 commit ee9aaa2

File tree

2 files changed

+55
-13
lines changed

2 files changed

+55
-13
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,13 @@ def L_op(self, inputs, outputs, output_grads):
527527
Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
528528
"""
529529

530+
from pytensor.tensor.slinalg import solve_triangular
531+
530532
(A,) = (cast(ptb.TensorVariable, x) for x in inputs)
531533
*_, m, n = A.type.shape
532534

533535
def _H(x: ptb.TensorVariable):
534-
return x.conj().T
536+
return x.conj().mT
535537

536538
def _copyutl(x: ptb.TensorVariable):
537539
return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
@@ -550,8 +552,9 @@ def _copyutl(x: ptb.TensorVariable):
550552
Q, R = qr(A, mode="reduced")
551553
dR = cast(ptb.TensorVariable, output_grads[0])
552554
R_dRt = R @ _H(dR)
553-
Rinvt = _H(inv(R))
554-
A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
555+
M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
556+
M_Rinvt = _H(solve_triangular(R, _H(M)))
557+
A_bar = Q @ (M_Rinvt + dR)
555558
return [A_bar]
556559

557560
else:
@@ -575,12 +578,11 @@ def _copyutl(x: ptb.TensorVariable):
575578

576579
(dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
577580

578-
Rinvt = _H(inv(R))
579581
Qt_dQ = _H(Q) @ dQ
580582
R_dRt = R @ _H(dR)
581-
A_bar = (
582-
Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
583-
) @ Rinvt + Q @ dR
583+
M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
584+
M_Rinvt = _H(solve_triangular(R, _H(M)))
585+
A_bar = M_Rinvt + Q @ dR
584586

585587
return [A_bar]
586588

tests/tensor/test_nlinalg.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,56 @@ def test_qr_modes():
152152
assert "name 'complete' is not defined" in str(e)
153153

154154

155-
@pytest.mark.parametrize("shape", [(3, 3), (6, 3)], ids=["shape=(3, 3)", "shape=(6,3)"])
156-
@pytest.mark.parametrize("output", [0, 1], ids=["Q", "R"])
157-
def test_qr_grad(shape, output):
155+
@pytest.mark.parametrize(
156+
"shape, gradient_test_case, mode",
157+
(
158+
[(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
159+
+ [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
160+
+ [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
161+
+ [((3, 3), 0, "raw")]
162+
),
163+
ids=(
164+
[
165+
f"shape={s}, gradient_test_case={c}, mode=reduced"
166+
for s in [(3, 3), (6, 3), (3, 6)]
167+
for c in ["Q", "R", "both"]
168+
]
169+
+ [
170+
f"shape={s}, gradient_test_case={c}, mode=complete"
171+
for s in [(3, 3), (6, 3), (3, 6)]
172+
for c in ["Q", "R", "both"]
173+
]
174+
+ [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
175+
+ ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
176+
),
177+
)
178+
def test_qr_grad(shape, gradient_test_case, mode):
158179
rng = np.random.default_rng(utt.fetch_seed())
159180

160-
def _test_fn(x):
161-
return qr(x, mode="reduced")[output]
181+
def _test_fn(x, case=2, mode="reduced"):
182+
if case == 0:
183+
return qr(x, mode=mode)[0].sum()
184+
elif case == 1:
185+
return qr(x, mode=mode)[1].sum()
186+
elif case == 2:
187+
Q, R = qr(x, mode=mode)
188+
return Q.sum() + R.sum()
162189

190+
m, n = shape
163191
a = rng.standard_normal(shape).astype(config.floatX)
164-
utt.verify_grad(_test_fn, [a], rng=np.random)
192+
193+
if m < n or (mode == "complete" and m != n) or mode == "raw":
194+
with pytest.raises(NotImplementedError):
195+
utt.verify_grad(
196+
partial(_test_fn, case=gradient_test_case, mode=mode),
197+
[a],
198+
rng=np.random,
199+
)
200+
201+
else:
202+
utt.verify_grad(
203+
partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
204+
)
165205

166206

167207
class TestSvd(utt.InferShapeTester):

0 commit comments

Comments
 (0)