Commit 5208662

Handle right-multiplication case
1 parent d9d7566 commit 5208662

2 files changed: 19 additions, 21 deletions

pytensor/tensor/rewriting/math.py

Lines changed: 11 additions & 18 deletions
@@ -155,7 +155,7 @@ def local_0_dot_x(fgraph, node):
 @node_rewriter([Dot])
 def local_block_diag_dot_to_dot_block_diag(fgraph, node):
     r"""
-    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``

     BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
     of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than

@@ -189,25 +189,18 @@ def check_for_block_diag(x):
         new_output = join(0, *new_components)
     elif not check_for_block_diag(x) and check_for_block_diag(y):
         components = y.owner.inputs
-        new_components = [op(x, component) for component in components]
-        new_output = join(0, *new_components)
-
-    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
-    # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
-    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
-        return None
-    elif x.ndim == y.ndim and all(
-        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
-    ):
-        x_components = x.owner.inputs
-        y_components = y.owner.inputs
+        x_splits = split(
+            x,
+            splits_size=[component.shape[0] for component in components],
+            n_splits=len(components),
+            axis=1,
+        )

-        if len(x_components) != len(y_components):
-            return None
+        new_components = [
+            op(x_split, component) for component, x_split in zip(components, x_splits)
+        ]
+        new_output = join(1, *new_components)

-        new_output = BlockDiagonal(len(x_components))(
-            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
-        )
     else:
         return None
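
The new branch relies on a simple identity: right-multiplying by ``block_diag(A, B)`` is the same as splitting the left operand's columns by the row counts of the blocks and joining the partial products along axis 1. A minimal NumPy/SciPy sketch of that identity (illustrative only, not code from this commit):

# NumPy sketch (not PyTensor code): d @ block_diag(a, b) equals splitting
# d's columns by the blocks' row counts, multiplying each slice by its
# block, and concatenating the results along axis 1.
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2))   # block with 4 rows
b = rng.normal(size=(2, 4))   # block with 2 rows
d = rng.normal(size=(10, 6))  # left operand; 6 columns = 4 + 2

reference = d @ block_diag(a, b)

d1, d2 = np.split(d, [a.shape[0]], axis=1)            # column split at 4
rewritten = np.concatenate([d1 @ a, d2 @ b], axis=1)  # what the rewrite builds

assert np.allclose(reference, rewritten)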

tests/tensor/rewriting/test_math.py

Lines changed: 8 additions & 3 deletions
@@ -4748,17 +4748,22 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
     )


-def test_local_block_diag_dot_to_dot_block_diag():
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     """
     Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
     a = tensor("a", shape=(4, 2))
     b = tensor("b", shape=(2, 4))
     c = tensor("c", shape=(4, 4))
-    d = tensor("d", shape=(10,))
+    d = tensor("d", shape=(10, 10))

     x = pt.linalg.block_diag(a, b, c)
-    out = x @ d
+
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x

     fn = pytensor.function([a, b, c, d], out)
     assert not any(
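
For context, a hedged end-to-end sketch of the right-multiplication path that the parametrized test now covers; the SciPy comparison and random inputs below are illustrative assumptions, not code from this commit:

import numpy as np
from scipy.linalg import block_diag as sp_block_diag

import pytensor
import pytensor.tensor as pt

# Symbolic graph mirroring the right-multiplication branch of the test.
a = pt.tensor("a", shape=(4, 2), dtype="float64")
b = pt.tensor("b", shape=(2, 4), dtype="float64")
c = pt.tensor("c", shape=(4, 4), dtype="float64")
d = pt.tensor("d", shape=(10, 10), dtype="float64")

x = pt.linalg.block_diag(a, b, c)  # (10, 10) block-diagonal matrix
out = d @ x                        # right multiplication

fn = pytensor.function([a, b, c, d], out)

# The compiled (rewritten) function should agree with a dense SciPy reference.
rng = np.random.default_rng(0)
a_val = rng.normal(size=(4, 2))
b_val = rng.normal(size=(2, 4))
c_val = rng.normal(size=(4, 4))
d_val = rng.normal(size=(10, 10))

expected = d_val @ sp_block_diag(a_val, b_val, c_val)
np.testing.assert_allclose(fn(a_val, b_val, c_val, d_val), expected)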
