Skip to content

Rewrite solves involving kron to eliminate kron #1559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
we get back our original input without having to compute inverse once.

Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
be simply rewritten.

Parameters
----------
Expand Down Expand Up @@ -855,6 +857,83 @@ def rewrite_det_kronecker(fgraph, node):
return [det_final]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
"""
Given a linear system of the form:

.. math::

(A \\otimes B) x = y

Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:

.. math::

x = \text{vec}(B^{-1} Y A^{-T})

Eliminating the kronecker product from the expression.
"""

if not isinstance(node.op.core_op, SolveBase):
return

solve_op = node.op
props_dict = solve_op.core_op._props_dict()
b_ndim = props_dict["b_ndim"]

A, b = node.inputs

if not A.owner or not (
isinstance(A.owner.op, KroneckerProduct)
or isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, KroneckerProduct)
):
Comment on lines +890 to +894
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably broke the parenthesis, but you get the idea. Negate the whole condition that is required and group the Blockwies + KroneckerProduct

Suggested change
if not A.owner or not (
isinstance(A.owner.op, KroneckerProduct)
or isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, KroneckerProduct)
):
if not (A.owner and (
isinstance(A.owner.op, KroneckerProduct)
or (isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, KroneckerProduct)))
):

return

x1, x2 = A.owner.inputs

# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
# We will proceed if they are unknown, but this makes the rewrite shape unsafe.
Comment on lines +899 to +900
Copy link
Member

@ricardoV94 ricardoV94 Aug 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape_unsafe is when a rewrite can mask an originally invalid graph, but it / we aren't allowed to turn a previously valid graph into an invalid one. Is that what's happening here?

x1_core_shapes = x1.type.shape[-2:]
x2_core_shapes = x2.type.shape[-2:]

if (
all(shape is not None for shape in x1_core_shapes)
and x1_core_shapes[-1] != x1_core_shapes[-2]
) or (
all(shape is not None for shape in x2_core_shapes)
and x2_core_shapes[-1] != x2_core_shapes[-2]
):
return None

m, n = x1.shape[-2], x2.shape[-2]
batch_shapes = x1.shape[:-2]
Copy link
Member

@ricardoV94 ricardoV94 Aug 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x1/x2 batch shapes could broadcast in blockwise


if b_ndim == 1:
# The rewritten expression will reshape B to be 2d. The easiest way to handle this is to just make a new
# solve node with n_ndim = 2
props_dict["b_ndim"] = 2
new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict))
B = b.reshape((*batch_shapes, m, n))
res = new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1))

else:
# If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional
# batch dimension
b_batch = b.shape[-1]
B = pt.moveaxis(b, -1, 0).reshape((b_batch, *batch_shapes, m, n))

res = pt.moveaxis(solve_op(x1, solve_op(x2, B.mT).mT), 0, -1).reshape(
(*batch_shapes, -1, b_batch)
)

return [res]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing copy_stack_trace



@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
Expand Down
150 changes: 150 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,156 @@ def test_slogdet_kronecker_rewrite():
)


def count_kron_ops(fgraph):
return sum(
[
isinstance(node.op, KroneckerProduct)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, KroneckerProduct)
)
for node in fgraph.apply_nodes
]
)


@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
"solve_op, solve_kwargs",
[
(pt.linalg.solve, {"assume_a": "gen"}),
(pt.linalg.solve, {"assume_a": "pos"}),
(pt.linalg.solve, {"assume_a": "upper triangular"}),
],
ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
# A and B have different shapes to make the test more interesting, but both need to be square matrices, otherwise
# the rewrite is invalid.
a_shape = (3, 3) if not add_batch else (2, 3, 3)
b_shape = (2, 2) if not add_batch else (2, 2, 2)
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)

m, n = a_shape[-2], b_shape[-2]
y_shape = (m * n,)
if b_ndim == 2:
y_shape = (m * n, 3)
if add_batch:
y_shape = (2, *y_shape)

y = pt.tensor("y", shape=y_shape)
C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

fn_expected = pytensor.function(
[A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
)
assert count_kron_ops(fn_expected.maker.fgraph) == 1

fn = pytensor.function([A, B, y], x)
assert count_kron_ops(fn.maker.fgraph) == 0

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape)
b_val = rng.normal(size=b_shape)
y_val = rng.normal(size=y_shape)

if solve_kwargs["assume_a"] == "pos":
a_val = a_val @ np.moveaxis(a_val, -2, -1)
b_val = b_val @ np.moveaxis(b_val, -2, -1)
elif solve_kwargs["assume_a"] == "upper triangular":
a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

if len(a_shape) > 2:
a_idx = (slice(None, None), *a_idx)
if len(b_shape) > 2:
b_idx = (slice(None, None), *b_idx)

a_val[a_idx] = 0
b_val[b_idx] = 0

a_val = a_val.astype(config.floatX)
b_val = b_val.astype(config.floatX)
y_val = y_val.astype(config.floatX)

expected = fn_expected(a_val, b_val, y_val)
result = fn(a_val, b_val, y_val)

if config.floatX == "float64":
tol = 1e-8
elif config.floatX == "float32" and not solve_kwargs["assume_a"] == "pos":
tol = 1e-4
else:
# Precision needs to be extremely low for the assume_a = pos test to pass in float32 mode. I don't have a
# good theory of why. Skipping this case would also be an option.
tol = 1e-2

np.testing.assert_allclose(
expected,
result,
atol=tol,
rtol=tol,
)


def test_rewrite_solve_kron_to_solve_not_applied():
# Check that the rewrite is not applied when the component matrices to the kron are static and not square
A = pt.tensor("A", shape=(3, 2))
B = pt.tensor("B", shape=(2, 3))
C = pt.linalg.kron(A, B)

y = pt.vector("y", shape=(6,))
x = pt.linalg.solve(C, y)

fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 1

# If shapes are static, it should always be applied
A = pt.tensor("A", shape=(3, None, None))
B = pt.tensor("B", shape=(3, None, None))
Comment on lines +939 to +941
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back to the previous comment, is the previous C a valid graph? If so, we can't rewrite and break the graph if we don't know the core shapes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C is valid, because C is square. The "problem" is that we can kron together two non-square matrices and end up with a square one (e.g. kron((4,3), (3,4)) -> (7, 7)). So the rewrite is invalid in this case.

This is another case where we really really wish we had a tag for "square matrix", without having to commit to shapes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wiki seems to suggest Kron(A, B) is only invertible if both A and B are invertible, so you couldn't solve C in the first place if this wasn't the case?

Is that correct? In that case it's fine to have the rewrite when the shapes are unknown (perhaps add a comment?). Otherwise it's not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The theory looks right.

The only issue I guess is that currently, you won't get an error if you have an "invalid" graph like:

A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))

A_pt, B_pt = pt.dmatrices('A', 'B')
y_pt = pt.dvector('y')
C = pt.linalg.kron(A_pt, B_pt)
x = pt.linalg.solve(C, y_pt)

fn = pytensor.function([A_pt, B_pt, y_pt], x)

You get a warning about numerical instability, but it gives you some numbers. Obviously these numbers are just nonsense, but it doesn't error. After the rewrite, you will get a shape error, which might be very surprising for someone who isn't providing a valid graph in the first place?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solve of C doesn't raise for "singular matrix"?

C = pt.linalg.kron(A, B)
y = pt.tensor("y", shape=(None,))
x = pt.linalg.solve(C, y)
fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 0


@pytest.mark.parametrize(
"a_shape, b_shape",
[((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)
C = pt.linalg.kron(A, B)

m, n = a_shape[-2], b_shape[-2]
has_batch = len(a_shape) == 3
y_shape = (a_shape[0], m * n) if has_batch else (m * n,)
y = pt.tensor("y", shape=y_shape)
x = pt.linalg.solve(C, y, b_ndim=1)

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
y_val = rng.normal(size=y_shape).astype(config.floatX)

mode = (
get_default_mode()
if rewrite
else get_default_mode().excluding("rewrite_solve_kron_to_solve")
)

fn = pytensor.function([A, B, y], x, mode=mode)
benchmark(fn, a_val, b_val, y_val)


def test_cholesky_eye_rewrite():
x = pt.eye(10)
L = pt.linalg.cholesky(x)
Expand Down