Skip to content

Commit 997ad75

Browse files
Blockwise LU
1 parent 331e8ab commit 997ad75

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

pytensor/tensor/slinalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from collections.abc import Sequence
44
from functools import reduce
5-
from typing import Literal, cast, Sequence
5+
from typing import Literal, cast
66

77
import numpy as np
88
import scipy.linalg as scipy_linalg
@@ -426,8 +426,8 @@ def make_node(self, x):
426426
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
427427
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
428428

429-
L = tensor(shape=x.type.shape, dtype=real_dtype)
430-
U = tensor(shape=x.type.shape, dtype=real_dtype)
429+
L = tensor(shape=x.type.shape, dtype=x.type.dtype)
430+
U = tensor(shape=x.type.shape, dtype=x.type.dtype)
431431

432432
if self.permute_l:
433433
# In this case, L is actually P @ L
@@ -497,7 +497,7 @@ def L_op(
497497
p, L, U = outputs
498498

499499
# TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
500-
P = pt.eye(A.shape[0])[p]
500+
P = pt.eye(A.shape[-1])[p]
501501
_, L_bar, U_bar = output_grads
502502
else:
503503
P, L, U = outputs
@@ -556,12 +556,16 @@ def lu(
556556
U: TensorVariable
557557
Upper triangular matrix
558558
"""
559-
return cast(
559+
op = cast(
560560
tuple[TensorVariable, TensorVariable, TensorVariable]
561561
| tuple[TensorVariable, TensorVariable],
562-
LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
562+
Blockwise(
563+
LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)
564+
),
563565
)
564566

567+
return op(a)
568+
565569

566570
class SolveTriangular(SolveBase):
567571
"""Solve a system of linear equations."""

tests/tensor/test_slinalg.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,30 +477,38 @@ def test_solve_dtype(self):
477477
@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
478478
@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
479479
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
480-
def test_lu_decomposition(permute_l, p_indices, complex):
480+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
481+
def test_lu_decomposition(
482+
permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int]
483+
):
481484
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
482-
A = tensor("A", shape=(None, None), dtype=dtype)
485+
486+
A = tensor("A", shape=shape, dtype=dtype)
483487
out = lu(A, permute_l=permute_l, p_indices=p_indices)
484488

485489
f = pytensor.function([A], out)
486490

487491
rng = np.random.default_rng(utt.fetch_seed())
488-
x = rng.normal(size=(5, 5)).astype(config.floatX)
492+
x = rng.normal(size=shape).astype(config.floatX)
489493
if complex:
490-
x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)
494+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
491495

492496
out = f(x)
493497

494498
if permute_l:
495499
PL, U = out
496-
x_rebuilt = PL @ U
497500
elif p_indices:
498501
p, L, U = out
499-
P = np.eye(5)[p]
500-
x_rebuilt = P @ L @ U
502+
if len(shape) == 2:
503+
P = np.eye(5)[p]
504+
else:
505+
P = np.stack([np.eye(5)[idx] for idx in p])
506+
PL = np.einsum("...nk,...km->...nm", P, L)
501507
else:
502508
P, L, U = out
503-
x_rebuilt = P @ L @ U
509+
PL = np.einsum("...nk,...km->...nm", P, L)
510+
511+
x_rebuilt = np.einsum("...nk,...km->...nm", PL, U)
504512

505513
np.testing.assert_allclose(x, x_rebuilt)
506514
scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
@@ -512,9 +520,10 @@ def test_lu_decomposition(permute_l, p_indices, complex):
512520
@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
513521
@pytest.mark.parametrize("permute_l", [True, False])
514522
@pytest.mark.parametrize("p_indices", [True, False])
515-
def test_lu_grad(grad_case, permute_l, p_indices):
523+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
524+
def test_lu_grad(grad_case, permute_l, p_indices, shape):
516525
rng = np.random.default_rng(utt.fetch_seed())
517-
A_value = rng.normal(size=(5, 5))
526+
A_value = rng.normal(size=shape)
518527

519528
def f_pt(A):
520529
out = lu(A, permute_l=permute_l, p_indices=p_indices)

0 commit comments

Comments
 (0)