Skip to content

Commit 3ab2bb0

Browse files
Handle case with multiple clients
1 parent 1a89309 commit 3ab2bb0

File tree

2 files changed

+51
-47
lines changed

2 files changed

+51
-47
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -170,53 +170,54 @@ def check_for_block_diag(x):
170170
)
171171

172172
# Check that the BlockDiagonal is an input to a Dot node:
173-
clients = list(get_clients_at_depth(fgraph, node, depth=1))
174-
if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
175-
return
173+
for client in get_clients_at_depth(fgraph, node, depth=1):
174+
if not isinstance(client.op, Dot):
175+
return
176176

177-
[dot_node] = clients
178-
op = dot_node.op
179-
x, y = dot_node.inputs
177+
op = client.op
178+
x, y = client.inputs
180179

181-
if not (check_for_block_diag(x) or check_for_block_diag(y)):
182-
return None
180+
if not (check_for_block_diag(x) or check_for_block_diag(y)):
181+
return None
183182

184-
# Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
185-
# non-block diagonal, and return a new block diagonal
186-
if check_for_block_diag(x) and not check_for_block_diag(y):
187-
components = x.owner.inputs
188-
y_splits = split(
189-
y,
190-
splits_size=[component.shape[-1] for component in components],
191-
n_splits=len(components),
192-
)
193-
new_components = [
194-
op(component, y_split) for component, y_split in zip(components, y_splits)
195-
]
196-
new_output = join(0, *new_components)
197-
198-
elif not check_for_block_diag(x) and check_for_block_diag(y):
199-
components = y.owner.inputs
200-
x_splits = split(
201-
x,
202-
splits_size=[component.shape[0] for component in components],
203-
n_splits=len(components),
204-
axis=1,
205-
)
183+
# Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
184+
# non-block diagonal, and return a new block diagonal
185+
if check_for_block_diag(x) and not check_for_block_diag(y):
186+
components = x.owner.inputs
187+
y_splits = split(
188+
y,
189+
splits_size=[component.shape[-1] for component in components],
190+
n_splits=len(components),
191+
)
192+
new_components = [
193+
op(component, y_split)
194+
for component, y_split in zip(components, y_splits)
195+
]
196+
new_output = join(0, *new_components)
197+
198+
elif not check_for_block_diag(x) and check_for_block_diag(y):
199+
components = y.owner.inputs
200+
x_splits = split(
201+
x,
202+
splits_size=[component.shape[0] for component in components],
203+
n_splits=len(components),
204+
axis=1,
205+
)
206206

207-
new_components = [
208-
op(x_split, component) for component, x_split in zip(components, x_splits)
209-
]
210-
new_output = join(1, *new_components)
207+
new_components = [
208+
op(x_split, component)
209+
for component, x_split in zip(components, x_splits)
210+
]
211+
new_output = join(1, *new_components)
211212

212-
# Case 2: Both inputs are BlockDiagonal. Do nothing
213-
else:
214-
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
215-
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
216-
return None
213+
# Case 2: Both inputs are BlockDiagonal. Do nothing
214+
else:
215+
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
216+
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
217+
return None
217218

218-
copy_stack_trace(node.outputs[0], new_output)
219-
return {dot_node.outputs[0]: new_output}
219+
copy_stack_trace(node.outputs[0], new_output)
220+
return {client.outputs[0]: new_output}
220221

221222

222223
@register_canonicalize

tests/tensor/rewriting/test_math.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4757,21 +4757,23 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
47574757
b = tensor("b", shape=(2, 4))
47584758
c = tensor("c", shape=(4, 4))
47594759
d = tensor("d", shape=(10, 10))
4760+
e = tensor("e", shape=(10, 10))
47604761

47614762
x = pt.linalg.block_diag(a, b, c)
47624763

4764+
# Test multiple clients are all rewritten
47634765
if left_multiply:
4764-
out = x @ d
4766+
out = [x @ d, x @ e]
47654767
else:
4766-
out = d @ x
4768+
out = [d @ x, e @ x]
47674769

4768-
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
4770+
fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
47694771
assert not any(
47704772
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
47714773
)
47724774

47734775
fn_expected = pytensor.function(
4774-
[a, b, c, d],
4776+
[a, b, c, d, e],
47754777
out,
47764778
mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
47774779
)
@@ -4781,10 +4783,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
47814783
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
47824784
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
47834785
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
4786+
e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
47844787

47854788
np.testing.assert_allclose(
4786-
fn(a_val, b_val, c_val, d_val),
4787-
fn_expected(a_val, b_val, c_val, d_val),
4789+
fn(a_val, b_val, c_val, d_val, e_val),
4790+
fn_expected(a_val, b_val, c_val, d_val, e_val),
47884791
atol=1e-6 if config.floatX == "float32" else 1e-12,
47894792
rtol=1e-6 if config.floatX == "float32" else 1e-12,
47904793
)

0 commit comments

Comments (0)