Skip to content

Commit 0e47b7d

Browse files
author
Etienne Duchesne
committed
qr decomposition gradient: add symbolic shape check
1 parent ee9aaa2 commit 0e47b7d

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.graph.basic import Apply
1212
from pytensor.graph.op import Op
1313
from pytensor.npy_2_compat import normalize_axis_tuple
14+
from pytensor.raise_op import Assert
1415
from pytensor.tensor import TensorLike
1516
from pytensor.tensor import basic as ptb
1617
from pytensor.tensor import math as ptm
@@ -530,26 +531,40 @@ def L_op(self, inputs, outputs, output_grads):
530531
from pytensor.tensor.slinalg import solve_triangular
531532

532533
(A,) = (cast(ptb.TensorVariable, x) for x in inputs)
533-
*_, m, n = A.type.shape
534534

535535
def _H(x: ptb.TensorVariable):
536536
return x.conj().mT
537537

538538
def _copyutl(x: ptb.TensorVariable):
539539
return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
540540

541-
if self.mode == "raw" or (self.mode == "complete" and m != n):
542-
raise NotImplementedError("Gradient of qr not implemented")
541+
if self.mode == "raw":
542+
raise NotImplementedError("Gradient of qr not implemented for mode=raw")
543543

544-
elif m < n:
545-
raise NotImplementedError(
546-
"Gradient of qr not implemented for m x n matrices with m < n"
544+
elif self.mode == "complete":
545+
Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
546+
qr_assert_op = Assert(
547+
"Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
547548
)
549+
R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
548550

549551
elif self.mode == "r":
552+
qr_assert_op = Assert(
553+
"Gradient of qr not implemented for m x n matrices with m < n and mode=r"
554+
)
555+
A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1]))
550556
# We need all the components of the QR to compute the gradient of A even if we only
551557
# use the upper triangular component in the cost function.
552558
Q, R = qr(A, mode="reduced")
559+
560+
else:
561+
Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
562+
qr_assert_op = Assert(
563+
"Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
564+
)
565+
R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
566+
567+
if self.mode == "r":
553568
dR = cast(ptb.TensorVariable, output_grads[0])
554569
R_dRt = R @ _H(dR)
555570
M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
@@ -558,8 +573,6 @@ def _copyutl(x: ptb.TensorVariable):
558573
return [A_bar]
559574

560575
else:
561-
Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
562-
563576
new_output_grads = []
564577
is_disconnected = [
565578
isinstance(x.type, DisconnectedType) for x in output_grads

tests/tensor/test_nlinalg.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,20 @@ def _test_fn(x, case=2, mode="reduced"):
190190
m, n = shape
191191
a = rng.standard_normal(shape).astype(config.floatX)
192192

193-
if m < n or (mode == "complete" and m != n) or mode == "raw":
193+
if mode == "raw":
194194
with pytest.raises(NotImplementedError):
195195
utt.verify_grad(
196196
partial(_test_fn, case=gradient_test_case, mode=mode),
197197
[a],
198198
rng=np.random,
199199
)
200+
elif m < n or (mode == "complete" and m != n):
201+
with pytest.raises(AssertionError):
202+
utt.verify_grad(
203+
partial(_test_fn, case=gradient_test_case, mode=mode),
204+
[a],
205+
rng=np.random,
206+
)
200207

201208
else:
202209
utt.verify_grad(

0 commit comments

Comments
 (0)