24 changes: 24 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -65,6 +65,30 @@
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)


@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
    """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""

    new_inputs = []
    changed = False

    for inp in node.inputs:
        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
            new_inputs.extend(inp.owner.inputs)
            changed = True
        else:
            new_inputs.append(inp)

    if changed:
        fused_op = BlockDiagonal(len(new_inputs))
        new_output = fused_op(*new_inputs)
        copy_stack_trace(node.outputs[0], new_output)
        return [new_output]

    return None


def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
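For context, here is a minimal sketch (not part of the diff) of what the new canonicalization is meant to do, assuming the fuse_blockdiagonal rewrite above is registered under the canonicalize tag as shown in the hunk; the variables a, b, c below are illustrative placeholders:

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.slinalg import BlockDiagonal, block_diag

a = pt.tensor("a", shape=(2, 2))
b = pt.tensor("b", shape=(2, 2))
c = pt.tensor("c", shape=(2, 2))

# Two BlockDiagonal applies before rewriting.
nested = block_diag(block_diag(a, b), c)

# Canonicalization (which picks up fuse_blockdiagonal) should leave a single
# BlockDiagonal op that consumes a, b and c directly.
fused = rewrite_graph(nested, include=("canonicalize",))
assert isinstance(fused.owner.op, BlockDiagonal)
assert len(fused.owner.inputs) == 3

The tests below assert the same behaviour on a FunctionGraph rewritten in place.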
65 changes: 64 additions & 1 deletion tests/tensor/rewriting/test_linalg.py
@@ -10,7 +10,7 @@
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph import FunctionGraph, ancestors
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
@@ -44,6 +44,69 @@
from tests.test_rop import break_op


def test_nested_blockdiag_fusion():
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))

    inner = BlockDiagonal(2)(x, y)
    outer = BlockDiagonal(2)(inner, z)

    initial_count = sum(
        1
        for node in ancestors([outer])
        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"

    fgraph = FunctionGraph(inputs=[x, y, z], outputs=[outer])
    rewrite_graph(fgraph, include=("fast_run", "fuse_blockdiagonal"))

    fused_nodes = [
        node for node in fgraph.toposort() if isinstance(node.op, BlockDiagonal)
    ]
    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"

    fused_op = fused_nodes[0].op
    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"

    out_shape = fgraph.outputs[0].type.shape
    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"


def test_deeply_nested_blockdiag_fusion():
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    w = pt.tensor("w", shape=(3, 3))

    inner1 = BlockDiagonal(2)(x, y)
    inner2 = BlockDiagonal(2)(inner1, z)
    outer = BlockDiagonal(2)(inner2, w)

    fgraph = FunctionGraph(inputs=[x, y, z, w], outputs=[outer])
    rewrite_graph(fgraph, include=("fast_run", "fuse_blockdiagonal"))

    fused_block_diag_nodes = [
        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
    ]
    assert len(fused_block_diag_nodes) == 1, (
        f"Expected 1 fused BlockDiagonal, got {len(fused_block_diag_nodes)}"
    )

    fused_block_diag_op = fused_block_diag_nodes[0].op

    assert fused_block_diag_op.n_inputs == 4, (
        f"Expected n_inputs=4 after fusion, got {fused_block_diag_op.n_inputs}"
    )

    out_shape = fgraph.outputs[0].type.shape
    expected_shape = (12, 12)  # 4 blocks of (3x3)
    assert out_shape == expected_shape, (
        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
    )


def test_matrix_inverse_rop_lop():
    rtol = 1e-7 if config.floatX == "float64" else 1e-5
    mx = matrix("mx")