Skip to content

Commit 5be594a

Browse files
Rewrite solves involving kron to eliminate kron
1 parent 892a8f0 commit 5be594a

File tree

2 files changed

+194
-2
lines changed

2 files changed

+194
-2
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
588588
@node_rewriter([Blockwise])
589589
def rewrite_inv_inv(fgraph, node):
590590
"""
591-
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
591+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
592+
we get back our original input without having to compute inverse once.
592593
593-
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
594+
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
595+
be simply rewritten.
594596
595597
Parameters
596598
----------
@@ -855,6 +857,69 @@ def rewrite_det_kronecker(fgraph, node):
855857
return [det_final]
856858

857859

860+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
    r"""
    Replace a solve against a Kronecker product by two solves against its factors.

    Given a linear system of the form:

    .. math::

        (A \otimes B) x = y

    Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
    define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:

    .. math::

        x = \text{vec}(B^{-1} Y A^{-T})

    Eliminating the kronecker product from the expression.

    The implementation works with the default (row-wise) ``reshape`` instead, for which the equivalent identity is
    ``x = (A^{-1} Y B^{-T}).reshape(-1)`` with ``Y = y.reshape(m, n)``.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        A single-element list holding the kron-free replacement, or None if the rewrite does not apply.
    """
    solve_op = node.op
    core_op = solve_op.core_op
    if not isinstance(core_op, SolveBase):
        return None

    props_dict = core_op._props_dict()
    b_ndim = props_dict["b_ndim"]

    A, b = node.inputs

    if not A.owner:
        return None

    kron_op = A.owner.op
    kron_is_blockwise = isinstance(kron_op, Blockwise)
    is_kron = isinstance(kron_op, KroneckerProduct) or (
        kron_is_blockwise and isinstance(kron_op.core_op, KroneckerProduct)
    )
    if not is_kron:
        return None

    x1, x2 = A.owner.inputs

    for factor in (x1, x2):
        factor_ndim = factor.type.ndim
        if factor_ndim < 2:
            return None
        # NOTE(review): a bare (non-Blockwise) kron is assumed to follow np.kron semantics, where leading
        # dimensions are mixed rather than treated as batch dimensions -- confirm. In that case the identity
        # below only holds for plain matrices.
        if not kron_is_blockwise and factor_ndim != 2:
            return None
        # The identity requires both factors to be square. Bail out when the static shapes prove otherwise;
        # unknown (None) dimensions are given the benefit of the doubt, as before.
        n_rows, n_cols = factor.type.shape[-2:]
        if n_rows is not None and n_cols is not None and n_rows != n_cols:
            return None

    m, n = x1.shape[-2], x2.shape[-2]

    # NOTE(review): the props of the original solve (e.g. ``assume_a``) are reused unchanged for each factor.
    # This presumes the structure assumed for the kron also holds for A and B individually -- confirm.

    # Batch dimensions are taken from ``b`` (for the reshape into Y) and from the solved result (for the final
    # ravel) rather than from ``x1``, so that batch dimensions contributed or broadcast by ``x2`` or ``b`` are
    # handled as well.
    if b_ndim == 1:
        # The rewritten expression will reshape b to be 2d. The easiest way to handle this is to just make a new
        # solve node with b_ndim = 2
        props_dict["b_ndim"] = 2
        matrix_solve_op = Blockwise(type(core_op)(**props_dict))

        Y = b.reshape((*b.shape[:-1], m, n))
        X = matrix_solve_op(x1, matrix_solve_op(x2, Y.mT).mT)
        res = X.reshape((*X.shape[:-2], -1))

    else:
        # If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional
        # (left-most) batch dimension
        b_batch = b.shape[-1]
        b_cols_first = pt.moveaxis(b, -1, 0)
        Y = b_cols_first.reshape((*b_cols_first.shape[:-1], m, n))

        X = solve_op(x1, solve_op(x2, Y.mT).mT)

        # Shape is now (*batch, m, n, b_batch); merge (m, n) back into a single m * n axis.
        X_cols_last = pt.moveaxis(X, 0, -1)
        res = X_cols_last.reshape((*X_cols_last.shape[:-3], -1, b_batch))

    return [res]
921+
922+
858923
@register_canonicalize
859924
@register_stabilize
860925
@node_rewriter([Blockwise])

tests/tensor/rewriting/test_linalg.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,133 @@ def test_slogdet_kronecker_rewrite():
828828
)
829829

830830

831+
@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
    "solve_op, solve_kwargs",
    [
        (pt.linalg.solve, {"assume_a": "gen"}),
        (pt.linalg.solve, {"assume_a": "pos"}),
        (pt.linalg.solve, {"assume_a": "upper triangular"}),
    ],
    ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
    """Check that solve(kron(A, B), y) compiles without a kron and matches the un-rewritten result."""
    # The two factors get different sizes to make the test more interesting. Both must be square matrices,
    # otherwise the rewrite is invalid.
    batch_dims = (2,) if add_batch else ()
    a_shape = (*batch_dims, 3, 3)
    b_shape = (*batch_dims, 2, 2)
    A = pt.tensor("A", shape=a_shape)
    B = pt.tensor("B", shape=b_shape)

    m, n = a_shape[-2], b_shape[-2]
    rhs_core_shape = (m * n, 3) if b_ndim == 2 else (m * n,)
    y_shape = (*batch_dims, *rhs_core_shape)

    y = pt.tensor("y", shape=y_shape)
    batched_kron = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")
    C = batched_kron(A, B)

    x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

    def count_kron_ops(fn):
        # Count apply nodes that are a kron, either bare or wrapped in a Blockwise.
        n_found = 0
        for apply_node in fn.maker.fgraph.apply_nodes:
            op = apply_node.op
            if isinstance(op, KroneckerProduct):
                n_found += 1
            elif isinstance(op, Blockwise) and isinstance(
                op.core_op, KroneckerProduct
            ):
                n_found += 1
        return n_found

    # Reference function: same graph, compiled with the rewrite under test switched off.
    mode_without_rewrite = get_default_mode().excluding("rewrite_solve_kron_to_solve")
    fn_expected = pytensor.function([A, B, y], x, mode=mode_without_rewrite)
    assert count_kron_ops(fn_expected) == 1

    fn = pytensor.function([A, B, y], x)
    assert (
        count_kron_ops(fn) == 0
    ), "Rewrite did not apply, KroneckerProduct found in the graph"

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape)
    b_val = rng.normal(size=b_shape)
    y_val = rng.normal(size=y_shape)

    assume_a = solve_kwargs["assume_a"]
    if assume_a == "pos":
        # Gram matrices X @ X^T, built per batch entry.
        a_val = a_val @ np.moveaxis(a_val, -2, -1)
        b_val = b_val @ np.moveaxis(b_val, -2, -1)
    elif assume_a == "upper triangular":
        # Zero out the strictly lower triangle of every (batched) matrix, in place.
        for factor_val in (a_val, b_val):
            rows, cols = np.tril_indices(
                n=factor_val.shape[-2], m=factor_val.shape[-1], k=-1
            )
            factor_val[..., rows, cols] = 0

    a_val = a_val.astype(config.floatX)
    b_val = b_val.astype(config.floatX)
    y_val = y_val.astype(config.floatX)

    expected = fn_expected(a_val, b_val, y_val)
    result = fn(a_val, b_val, y_val)

    if config.floatX == "float64":
        tol = 1e-8
    elif config.floatX == "float32" and assume_a != "pos":
        tol = 1e-4
    else:
        # float32 with assume_a = pos only passes with a very loose tolerance. The cause is not established
        # (presumably conditioning of the random Gram matrices -- unconfirmed). Skipping this case would also
        # be an option.
        tol = 1e-2

    np.testing.assert_allclose(expected, result, atol=tol, rtol=tol)
925+
926+
927+
@pytest.mark.parametrize(
    "a_shape, b_shape",
    [((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
    ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
    """Benchmark solve(kron(A, B), y) with and without the kron-eliminating rewrite."""
    A = pt.tensor("A", shape=a_shape)
    B = pt.tensor("B", shape=b_shape)
    C = pt.linalg.kron(A, B)

    m, n = a_shape[-2], b_shape[-2]
    # A 3-d ``a_shape`` is treated as batched; the right-hand side then gets the same leading dimension.
    leading_dims = (a_shape[0],) if len(a_shape) == 3 else ()
    y_shape = (*leading_dims, m * n)
    y = pt.tensor("y", shape=y_shape)
    x = pt.linalg.solve(C, y, b_ndim=1)

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape).astype(config.floatX)
    b_val = rng.normal(size=b_shape).astype(config.floatX)
    y_val = rng.normal(size=y_shape).astype(config.floatX)

    mode = get_default_mode()
    if not rewrite:
        mode = mode.excluding("rewrite_solve_kron_to_solve")

    fn = pytensor.function([A, B, y], x, mode=mode)
    benchmark(fn, a_val, b_val, y_val)
956+
957+
831958
def test_cholesky_eye_rewrite():
832959
x = pt.eye(10)
833960
L = pt.linalg.cholesky(x)

0 commit comments

Comments
 (0)