Commit 3ceece7

Add tests for batched LU
1 parent 997ad75 commit 3ceece7

2 files changed: +50 -44 lines


pytensor/tensor/slinalg.py

Lines changed: 31 additions & 29 deletions
@@ -391,6 +391,8 @@ class LU(Op):
     def __init__(
         self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
     ):
+        if permute_l and p_indices:
+            raise ValueError("Only one of permute_l and p_indices can be True")
         self.permute_l = permute_l
         self.check_finite = check_finite
         self.p_indices = p_indices
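
For illustration (not part of this diff): the new check makes permute_l and p_indices mutually exclusive, so constructing the Op with both set now fails. A minimal sketch, assuming LU can be imported directly from pytensor.tensor.slinalg:

import pytest
from pytensor.tensor.slinalg import LU

LU(permute_l=True)    # valid: the Op will output (P @ L, U)
LU(p_indices=True)    # valid: the Op will output (p, L, U)
with pytest.raises(ValueError, match="Only one of permute_l and p_indices"):
    LU(permute_l=True, p_indices=True)  # now rejected at construction time
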
@@ -432,12 +434,12 @@ def make_node(self, x):
         if self.permute_l:
             # In this case, L is actually P @ L
             return Apply(self, inputs=[x], outputs=[L, U])
-        elif self.p_indices:
-            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
-            return Apply(self, inputs=[x], outputs=[p, L, U])
-        else:
-            P = tensor(shape=x.type.shape, dtype=p_dtype)
-            return Apply(self, inputs=[x], outputs=[P, L, U])
+        if self.p_indices:
+            p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p_indices, L, U])
+
+        P = tensor(shape=x.type.shape, dtype=p_dtype)
+        return Apply(self, inputs=[x], outputs=[P, L, U])

     def perform(self, node, inputs, outputs):
         [A] = inputs
@@ -479,30 +481,24 @@ def L_op(
         A = cast(TensorVariable, A)

         if self.permute_l:
-            PL_bar, U_bar = output_grads
+            # P has no gradient contribution (by assumption...), so PL_bar is the same as L_bar
+            L_bar, U_bar = output_grads

             # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
-            P, L, U = lu(  # type: ignore
+            # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
+            P_or_indices, L, U = lu(  # type: ignore
                 A, permute_l=False, check_finite=self.check_finite, p_indices=False
             )

-            # Permutation matrix is orthogonal
-            L_bar = (
-                P.T @ PL_bar
-                if not isinstance(PL_bar.type, DisconnectedType)
-                else pt.zeros_like(A)
-            )
-
-        elif self.p_indices:
-            p, L, U = outputs
-
-            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
-            P = pt.eye(A.shape[-1])[p]
-            _, L_bar, U_bar = output_grads
         else:
-            P, L, U = outputs
+            # In both other cases, there are 3 outputs. The first output will either be the permutation index itself,
+            # or indices that can be used to reconstruct the permutation matrix.
+            P_or_indices, L, U = outputs
             _, L_bar, U_bar = output_grads

+        L = pytensor.printing.Print("L")(L)
+        U = pytensor.printing.Print("U")(U)
+
         L_bar = (
             L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
         )
@@ -513,9 +509,17 @@ def L_op(
         x1 = ptb.tril(L.T @ L_bar, k=-1)
         x2 = ptb.triu(U_bar @ U.T)

-        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
-        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)

+        # Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation
+        B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
+
+        if not self.p_indices:
+            A_bar = P_or_indices @ B_bar
+        else:
+            A_bar = B_bar[P_or_indices]
+
+        A_bar = pytensor.printing.Print("A_bar")(A_bar)
         return [A_bar]

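For context (not part of this diff): the p_indices branch replaces the matrix product P @ B_bar with row indexing B_bar[p]. That relies on the identity eye(n)[p] @ B == B[p], which is also how the removed code rebuilt P from the indices. A quick NumPy check of that identity:

import numpy as np

rng = np.random.default_rng(0)
n = 5
p = rng.permutation(n)              # permutation indices, as returned when p_indices=True
P = np.eye(n)[p]                    # permutation matrix rebuilt from the indices
B_bar = rng.normal(size=(n, n))

# Left-multiplying by P selects rows p, so the new branch can skip building P entirely
assert np.allclose(P @ B_bar, B_bar[p])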

@@ -556,16 +560,14 @@ def lu(
     U: TensorVariable
         Upper triangular matrix
     """
-    op = cast(
+    return cast(
         tuple[TensorVariable, TensorVariable, TensorVariable]
         | tuple[TensorVariable, TensorVariable],
         Blockwise(
-            LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)
-        ),
+            LU(permute_l=permute_l, p_indices=p_indices, check_finite=check_finite)
+        )(a),
     )

-    return op(a)
-

 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""

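For orientation (not part of this diff), a small sketch of the three output modes lu supports after this change, assuming the import path pytensor/tensor/slinalg.py shown above; it compiles each variant and checks the factorization numerically:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu

A = pt.matrix("A")
A_val = np.random.default_rng(0).normal(size=(5, 5)).astype(pytensor.config.floatX)

PL, U0 = lu(A, permute_l=True)      # two outputs: the L factor is already P @ L
p, L1, U1 = lu(A, p_indices=True)   # permutation returned as an index vector
P, L2, U2 = lu(A)                   # permutation returned as a dense matrix

f = pytensor.function([A], [PL, U0, p, L1, U1, P, L2, U2])
PL_v, U0_v, p_v, L1_v, U1_v, P_v, L2_v, U2_v = f(A_val)

np.testing.assert_allclose(PL_v @ U0_v, A_val, atol=1e-6)
np.testing.assert_allclose(P_v @ L2_v @ U2_v, A_val, atol=1e-6)
# With indices, P can be rebuilt as eye(n)[p], matching the convention used in L_op above
np.testing.assert_allclose(np.eye(5)[p_v] @ L1_v @ U1_v, A_val, atol=1e-6)
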
tests/tensor/test_slinalg.py

Lines changed: 19 additions & 15 deletions
@@ -474,8 +474,11 @@ def test_solve_dtype(self):
         assert x.dtype == x_result.dtype, (A_dtype, b_dtype)


-@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
-@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
+@pytest.mark.parametrize(
+    "permute_l, p_indices",
+    [(False, True), (True, False), (False, False)],
+    ids=["PL", "p_indices", "P"],
+)
 @pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
 @pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
 def test_lu_decomposition(
@@ -517,29 +520,30 @@ def test_lu_decomposition(
     np.testing.assert_allclose(a, b)


-@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
-@pytest.mark.parametrize("permute_l", [True, False])
-@pytest.mark.parametrize("p_indices", [True, False])
+@pytest.mark.parametrize(
+    "grad_case", [0, 1, 2], ids=["dU_only", "dL_only", "dU_and_dL"]
+)
+@pytest.mark.parametrize(
+    "permute_l, p_indices",
+    [(True, False), (False, True), (False, False)],
+    ids=["PL", "p_indices", "P"],
+)
 @pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
 def test_lu_grad(grad_case, permute_l, p_indices, shape):
     rng = np.random.default_rng(utt.fetch_seed())
-    A_value = rng.normal(size=shape)
+    A_value = rng.normal(size=shape).astype(config.floatX)

     def f_pt(A):
-        out = lu(A, permute_l=permute_l, p_indices=p_indices)
-
-        if permute_l:
-            L, U = out
-        else:
-            _, L, U = out
+        # lu returns either (P_or_index, L, U) or (PL, U), depending on settings
+        out = lu(A, permute_l=permute_l, p_indices=p_indices, check_finite=False)

         match grad_case:
             case 0:
-                return U.sum()
+                return out[-1].sum()
             case 1:
-                return L.sum()
+                return out[-2].sum()
             case 2:
-                return U.sum() + L.sum()
+                return out[-1].sum() + out[-2].sum()

     utt.verify_grad(f_pt, [A_value], rng=rng)
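
A short follow-up sketch (not part of this diff) of why the negative indexing in the rewritten test works for every parametrization: U is always the last output of lu and L (or P @ L) is always the second-to-last, whether two or three outputs are returned.

import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu

A = pt.matrix("A")
for kwargs in ({"permute_l": True}, {"p_indices": True}, {}):
    out = lu(A, check_finite=False, **kwargs)
    U, L_like = out[-1], out[-2]    # holds for both the 2-output and 3-output forms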
