Skip to content

Commit 2b8125a

Browse files
Rewrite solves involving kron to eliminate kron
1 parent 892a8f0 commit 2b8125a

File tree

2 files changed

+144
-2
lines changed

2 files changed

+144
-2
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
588588
@node_rewriter([Blockwise])
589589
def rewrite_inv_inv(fgraph, node):
590590
"""
591-
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
591+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
592+
we get back our original input without having to compute inverse once.
592593
593-
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
594+
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
595+
be simply rewritten.
594596
595597
Parameters
596598
----------
@@ -855,6 +857,59 @@ def rewrite_det_kronecker(fgraph, node):
855857
return [det_final]
856858

857859

860+
@register_canonicalize
861+
@register_stabilize
862+
@node_rewriter([Blockwise])
863+
def rewrite_solve_kron_to_solve(fgraph, node):
864+
"""
865+
Given a linear system of the form:
866+
867+
.. math:
868+
869+
(A \\otimes B) x = y
870+
871+
Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F")`` in code). Further,
872+
define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:
873+
874+
..math::
875+
876+
x = \text{vec}(B^{-1} Y A^{-T})
877+
878+
Eliminating the kronecker product from the expression.
879+
"""
880+
881+
if not isinstance(node.op.core_op, SolveBase):
882+
return
883+
884+
solve_op = node.op
885+
props_dict = solve_op.core_op._props_dict()
886+
887+
if props_dict["b_ndim"] != 1:
888+
# The formula used in the rewrite requires that b is a vector, otherwise it's not clear how to reshape it
889+
# to conform with the components of the kronecker product.
890+
return
891+
892+
A, b = node.inputs
893+
894+
if not A.owner or not (
895+
isinstance(A.owner.op, KroneckerProduct)
896+
or isinstance(A.owner.op, Blockwise)
897+
and isinstance(A.owner.op.core_op, KroneckerProduct)
898+
):
899+
return
900+
901+
x1, x2 = A.owner.inputs
902+
903+
m, n = x1.shape[-2], x2.shape[-2]
904+
batch_shapes = x1.shape[:-2]
905+
B = b.reshape((*batch_shapes, m, n))
906+
907+
props_dict["b_ndim"] = 2
908+
new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict))
909+
910+
return [new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1))]
911+
912+
858913
@register_canonicalize
859914
@register_stabilize
860915
@node_rewriter([Blockwise])

tests/tensor/rewriting/test_linalg.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,93 @@ def test_slogdet_kronecker_rewrite():
828828
)
829829

830830

831+
@pytest.mark.parametrize(
832+
"a_shape, b_shape",
833+
[((5, 5), (5, 5)), ((5, 5), (3, 3)), ((3, 4, 4), (3, 3, 3))],
834+
ids=["same", "diff", "batch_diff"],
835+
)
836+
@pytest.mark.parametrize(
837+
"solve_op, solve_kwargs",
838+
[
839+
(pt.linalg.solve, {"assume_a": "gen"}),
840+
(pt.linalg.solve, {"assume_a": "pos"}),
841+
(pt.linalg.solve, {"assume_a": "upper triangular"}),
842+
],
843+
ids=["general", "positive definite", "triangular"],
844+
)
845+
def test_rewrite_solve_kron_to_solve(a_shape, b_shape, solve_op, solve_kwargs):
846+
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)
847+
848+
m, n = a_shape[-2], b_shape[-2]
849+
has_batch = len(a_shape) == 3
850+
y_shape = (a_shape[0], m * n) if has_batch else (m * n,)
851+
y = pt.tensor("y", shape=y_shape)
852+
C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)
853+
854+
x = solve_op(C, y, **solve_kwargs, b_ndim=1)
855+
fn_expected = pytensor.function(
856+
[A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
857+
)
858+
assert (
859+
sum(
860+
[
861+
isinstance(node.op, KroneckerProduct)
862+
or (
863+
isinstance(node.op, Blockwise)
864+
and isinstance(node.op.core_op, KroneckerProduct)
865+
)
866+
for node in fn_expected.maker.fgraph.apply_nodes
867+
]
868+
)
869+
== 1
870+
)
871+
872+
fn = pytensor.function([A, B, y], x)
873+
assert (
874+
sum(
875+
[
876+
isinstance(node.op, KroneckerProduct)
877+
or (
878+
isinstance(node.op, Blockwise)
879+
and isinstance(node.op.core_op, KroneckerProduct)
880+
)
881+
for node in fn.maker.fgraph.apply_nodes
882+
]
883+
)
884+
== 0
885+
)
886+
887+
rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
888+
a_val = rng.normal(size=a_shape).astype(config.floatX)
889+
b_val = rng.normal(size=b_shape).astype(config.floatX)
890+
y_val = rng.normal(size=y_shape).astype(config.floatX)
891+
892+
if solve_kwargs["assume_a"] == "pos":
893+
a_val = a_val @ a_val.mT
894+
b_val = b_val @ b_val.mT
895+
elif solve_kwargs["assume_a"] == "upper triangular":
896+
a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
897+
b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)
898+
899+
if len(a_shape) > 2:
900+
a_idx = (slice(None, None), *a_idx)
901+
if len(b_shape) > 2:
902+
b_idx = (slice(None, None), *b_idx)
903+
904+
a_val[a_idx] = 0
905+
b_val[b_idx] = 0
906+
907+
expected = fn_expected(a_val, b_val, y_val)
908+
result = fn(a_val, b_val, y_val)
909+
910+
np.testing.assert_allclose(
911+
expected,
912+
result,
913+
atol=1e-8 if config.floatX == "float64" else 1e-5,
914+
rtol=1e-8 if config.floatX == "float64" else 1e-5,
915+
)
916+
917+
831918
def test_cholesky_eye_rewrite():
832919
x = pt.eye(10)
833920
L = pt.linalg.cholesky(x)

0 commit comments

Comments
 (0)