Skip to content

Commit 0f76af9

Browse files
Rewrite solves involving kron to eliminate kron
1 parent 892a8f0 commit 0f76af9

File tree

2 files changed

+179
-2
lines changed

2 files changed

+179
-2
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
588588
@node_rewriter([Blockwise])
589589
def rewrite_inv_inv(fgraph, node):
590590
"""
591-
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
591+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
592+
we get back our original input without having to compute inverse once.
592593
593-
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
594+
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
595+
be simply rewritten.
594596
595597
Parameters
596598
----------
@@ -855,6 +857,69 @@ def rewrite_det_kronecker(fgraph, node):
855857
return [det_final]
856858

857859

860+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
    r"""
    Replace a solve against a Kronecker product with two solves against its factors.

    Given a linear system of the form:

    .. math::

        (A \otimes B) x = y

    Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
    define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:

    .. math::

        x = \text{vec}(B^{-1} Y A^{-T})

    Eliminating the kronecker product from the expression.

    The implementation uses row-major (C-order) reshapes, for which the equivalent identity is
    :math:`x = \text{vec}_r(A^{-1} Y B^{-T})` with ``Y = y.reshape(m, n)``.

    Parameters
    ----------
    fgraph
        Function graph the node belongs to. It is not used directly by this rewrite.
    node
        A ``Blockwise`` node. The rewrite only applies when its core op is a ``SolveBase`` and its first input is
        the direct output of a ``KroneckerProduct`` (plain or wrapped in ``Blockwise``).

    Returns
    -------
    list of Variable or None
        A single-element list holding the replacement output, or ``None`` when the rewrite does not apply.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None

    A, b = node.inputs
    if A.owner is None:
        return None

    kron_op = A.owner.op
    is_kron = isinstance(kron_op, KroneckerProduct) or (
        isinstance(kron_op, Blockwise) and isinstance(kron_op.core_op, KroneckerProduct)
    )
    if not is_kron:
        return None

    x1, x2 = A.owner.inputs

    # The factorised solve needs each factor to be square. A product such as (2, 3) kron (3, 2) is square even
    # though neither factor is, so bail out when the static shapes prove a factor is rectangular.
    for factor in (x1, x2):
        static_core = factor.type.shape[-2:]
        if (
            len(static_core) == 2
            and None not in static_core
            and static_core[0] != static_core[1]
        ):
            return None

    m, n = x1.shape[-2], x2.shape[-2]

    props_dict = core_op._props_dict()
    b_ndim = props_dict["b_ndim"]

    if b_ndim == 1:
        # The rewritten expression reshapes b into a matrix, so the inner solves need b_ndim = 2. The easiest way
        # to get that is to build a new solve op with the same props but b_ndim = 2.
        props_dict["b_ndim"] = 2
        matrix_solve = Blockwise(type(core_op)(**props_dict))

        # Take the batch shape from b itself rather than from x1: the batch dimensions of x1, x2 and b may
        # broadcast against each other, so none of the factors is guaranteed to carry b's batch shape.
        Y = b.reshape((*b.shape[:-1], m, n))
        X = matrix_solve(x1, matrix_solve(x2, Y.mT).mT)
        res = X.reshape((*X.shape[:-2], -1))

    else:
        # If b_ndim is 2, each column of b is an independent right-hand side. Move that axis to the front so it
        # acts as an additional batch dimension, then move it back once both solves are done.
        solve_op = node.op
        n_rhs = b.shape[-1]
        Y = pt.moveaxis(b, -1, 0).reshape((n_rhs, *b.shape[:-2], m, n))

        X = pt.moveaxis(solve_op(x1, solve_op(x2, Y.mT).mT), 0, -1)
        # X is (*batch, m, n, n_rhs); collapse (m, n) back into a single axis of length m * n.
        res = X.reshape((*X.shape[:-3], -1, n_rhs))

    return [res]
921+
922+
858923
@register_canonicalize
859924
@register_stabilize
860925
@node_rewriter([Blockwise])

tests/tensor/rewriting/test_linalg.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,118 @@ def test_slogdet_kronecker_rewrite():
828828
)
829829

830830

831+
@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
    "solve_op, solve_kwargs",
    [
        (pt.linalg.solve, {"assume_a": "gen"}),
        (pt.linalg.solve, {"assume_a": "pos"}),
        (pt.linalg.solve, {"assume_a": "upper triangular"}),
    ],
    ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
    """Check that solve(kron(A, B), y) is rewritten without a kron and gives the same numbers.

    The graph is compiled twice, once with ``rewrite_solve_kron_to_solve`` excluded and once with the default
    mode. The first must keep exactly one KroneckerProduct node, the second none, and both must agree
    numerically on random inputs conditioned to match ``assume_a``.
    """
    a_shape = (5, 5) if not add_batch else (3, 5, 5)
    b_shape = (5, 5) if not add_batch else (3, 5, 5)
    A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)

    m, n = a_shape[-2], b_shape[-2]
    y_shape = (m * n,)
    if b_ndim == 2:
        y_shape = (m * n, 5)
    if add_batch:
        y_shape = (3, *y_shape)

    y = pt.tensor("y", shape=y_shape)
    C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

    x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

    def count_kron_ops(fn):
        # Count both bare KroneckerProduct nodes and Blockwise-wrapped ones in the compiled graph.
        return sum(
            isinstance(node.op, KroneckerProduct)
            or (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, KroneckerProduct)
            )
            for node in fn.maker.fgraph.apply_nodes
        )

    fn_expected = pytensor.function(
        [A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
    )
    assert count_kron_ops(fn_expected) == 1

    fn = pytensor.function([A, B, y], x)
    assert (
        count_kron_ops(fn) == 0
    ), "Rewrite did not apply, KroneckerProduct found in the graph"

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape).astype(config.floatX)
    b_val = rng.normal(size=b_shape).astype(config.floatX)
    y_val = rng.normal(size=y_shape).astype(config.floatX)

    if solve_kwargs["assume_a"] == "pos":
        # M @ M^T over the last two axes. np.swapaxes is used instead of ndarray.mT, which only exists in
        # NumPy >= 2.0.
        a_val = a_val @ np.swapaxes(a_val, -1, -2)
        b_val = b_val @ np.swapaxes(b_val, -1, -2)
    elif solve_kwargs["assume_a"] == "upper triangular":
        # Zero the strictly-lower triangle. The leading Ellipsis covers any batch dimensions, so the batched
        # and unbatched cases share one code path.
        a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
        b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

        a_val[(..., *a_idx)] = 0
        b_val[(..., *b_idx)] = 0

    expected = fn_expected(a_val, b_val, y_val)
    result = fn(a_val, b_val, y_val)

    # assert_allclose takes (actual, desired); rtol is scaled by ``desired``, so the reference goes second.
    tol = 1e-8 if config.floatX == "float64" else 1e-5
    np.testing.assert_allclose(result, expected, atol=tol, rtol=tol)
910+
911+
912+
@pytest.mark.parametrize(
    "a_shape, b_shape",
    [((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
    ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
    """Benchmark solve(kron(A, B), y) with and without ``rewrite_solve_kron_to_solve``.

    ``rewrite`` selects between the default compilation mode and the same mode with the rewrite excluded.
    The compiled function is then handed to the ``benchmark`` fixture together with random inputs.
    """
    rows_a, rows_b = a_shape[-2], b_shape[-2]
    if len(a_shape) == 3:
        # Batched factors: y carries the same leading batch dimension as A.
        y_shape = (a_shape[0], rows_a * rows_b)
    else:
        y_shape = (rows_a * rows_b,)

    A = pt.tensor("A", shape=a_shape)
    B = pt.tensor("B", shape=b_shape)
    y = pt.tensor("y", shape=y_shape)
    x = pt.linalg.solve(pt.linalg.kron(A, B), y, b_ndim=1)

    compile_mode = get_default_mode()
    if not rewrite:
        compile_mode = compile_mode.excluding("rewrite_solve_kron_to_solve")
    fn = pytensor.function([A, B, y], x, mode=compile_mode)

    # Draw A, B, y in that order from one seeded generator so the inputs are reproducible.
    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    test_values = [
        rng.normal(size=shape).astype(config.floatX)
        for shape in (a_shape, b_shape, y_shape)
    ]

    benchmark(fn, *test_values)
941+
942+
831943
def test_cholesky_eye_rewrite():
832944
x = pt.eye(10)
833945
L = pt.linalg.cholesky(x)

0 commit comments

Comments
 (0)