
Commit ffb71d3

Handle right-multiplication case
1 parent: 9ad9540

File tree: 2 files changed (+19, -21 lines)


pytensor/tensor/rewriting/math.py

Lines changed: 11 additions & 18 deletions
@@ -176,7 +176,7 @@ def local_0_dot_x(fgraph, node):
 @node_rewriter([Dot])
 def local_block_diag_dot_to_dot_block_diag(fgraph, node):
     r"""
-    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``

     BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
     of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
@@ -210,25 +210,18 @@ def check_for_block_diag(x):
         new_output = join(0, *new_components)
     elif not check_for_block_diag(x) and check_for_block_diag(y):
         components = y.owner.inputs
-        new_components = [op(x, component) for component in components]
-        new_output = join(0, *new_components)
-
-    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
-    # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
-    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
-        return None
-    elif x.ndim == y.ndim and all(
-        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
-    ):
-        x_components = x.owner.inputs
-        y_components = y.owner.inputs
+        x_splits = split(
+            x,
+            splits_size=[component.shape[0] for component in components],
+            n_splits=len(components),
+            axis=1,
+        )

-        if len(x_components) != len(y_components):
-            return None
+        new_components = [
+            op(x_split, component) for component, x_split in zip(components, x_splits)
+        ]
+        new_output = join(1, *new_components)

-        new_output = BlockDiagonal(len(x_components))(
-            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
-        )
     else:
         return None
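
For context, the new branch handles dot(x, block_diag(c1, ..., cn)) by splitting x column-wise into chunks whose widths match the row counts of the blocks, multiplying each chunk by its block, and joining the results along axis 1. A minimal NumPy sketch of that identity (illustrative only, not part of the commit; the array names and shapes are invented):

    import numpy as np
    from scipy.linalg import block_diag

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 2))   # first block: 4 rows, 2 columns
    B = rng.normal(size=(3, 5))   # second block: 3 rows, 5 columns
    C = rng.normal(size=(6, 7))   # left operand: its 7 columns = 4 + 3 block rows

    # Direct product against the assembled block-diagonal matrix.
    direct = C @ block_diag(A, B)

    # Rewritten form: split C column-wise by the blocks' row counts,
    # multiply each piece by its block, and join along axis 1.
    C1, C2 = np.split(C, [A.shape[0]], axis=1)
    rewritten = np.concatenate([C1 @ A, C2 @ B], axis=1)

    np.testing.assert_allclose(direct, rewritten)

This form never materializes the large block-diagonal matrix, which is the point of the rewrite given the roughly O(n^3) cost of dot noted in the docstring.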

tests/tensor/rewriting/test_math.py

Lines changed: 8 additions & 3 deletions
@@ -4657,17 +4657,22 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
     )


-def test_local_block_diag_dot_to_dot_block_diag():
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     """
     Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
     a = tensor("a", shape=(4, 2))
     b = tensor("b", shape=(2, 4))
     c = tensor("c", shape=(4, 4))
-    d = tensor("d", shape=(10,))
+    d = tensor("d", shape=(10, 10))

     x = pt.linalg.block_diag(a, b, c)
-    out = x @ d
+
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x

     fn = pytensor.function([a, b, c, d], out)
     assert not any(
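
The parametrized test above covers both orientations. A hedged sketch of how the right-multiplication case can be checked numerically end to end (the shapes mirror the test; the value names and the SciPy reference are illustrative, not part of the commit):

    import numpy as np
    import scipy.linalg

    import pytensor
    import pytensor.tensor as pt

    floatX = pytensor.config.floatX

    a = pt.tensor("a", shape=(4, 2))
    b = pt.tensor("b", shape=(2, 4))
    c = pt.tensor("c", shape=(4, 4))
    d = pt.tensor("d", shape=(10, 10))

    x = pt.linalg.block_diag(a, b, c)
    out = d @ x  # the right-multiplication case added by this commit

    fn = pytensor.function([a, b, c, d], out)

    rng = np.random.default_rng(0)
    a_val = rng.normal(size=(4, 2)).astype(floatX)
    b_val = rng.normal(size=(2, 4)).astype(floatX)
    c_val = rng.normal(size=(4, 4)).astype(floatX)
    d_val = rng.normal(size=(10, 10)).astype(floatX)

    # Compare the compiled graph against a dense reference built with SciPy.
    expected = d_val @ scipy.linalg.block_diag(a_val, b_val, c_val)
    np.testing.assert_allclose(fn(a_val, b_val, c_val, d_val), expected, rtol=1e-5)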
