Skip to content

Commit 3b66eba

Browse files
Respond to feedback
1 parent c5137d7 commit 3b66eba

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -170,10 +170,8 @@ def local_0_dot_x(fgraph, node):
170170
return [constant_zero]
171171

172172

173-
@register_canonicalize
174-
@register_specialize
175173
@register_stabilize
176-
@node_rewriter([Dot])
174+
@node_rewriter([Blockwise])
177175
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
178176
r"""
179177
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C_1), dot(B, C_2))``, where ``C_1`` and ``C_2`` are the splits of ``C`` paired with ``A`` and ``B``
@@ -182,8 +180,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
182180
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
183181
a single dot on the larger matrix.
184182
"""
185-
x, y = node.inputs
186-
op = node.op
183+
if not isinstance(node.op.core_op, BlockDiagonal):
184+
return
187185

188186
def check_for_block_diag(x):
189187
return x.owner and (
@@ -192,6 +190,15 @@ def check_for_block_diag(x):
192190
and isinstance(x.owner.op.core_op, BlockDiagonal)
193191
)
194192

193+
# Check that the BlockDiagonal has exactly one client and that this client is a Dot node:
194+
clients = list(get_clients_at_depth(fgraph, node, depth=1))
195+
if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
196+
return
197+
198+
[dot_node] = clients
199+
op = dot_node.op
200+
x, y = dot_node.inputs
201+
195202
if not (check_for_block_diag(x) or check_for_block_diag(y)):
196203
return None
197204

@@ -208,6 +215,7 @@ def check_for_block_diag(x):
208215
op(component, y_split) for component, y_split in zip(components, y_splits)
209216
]
210217
new_output = join(0, *new_components)
218+
211219
elif not check_for_block_diag(x) and check_for_block_diag(y):
212220
components = y.owner.inputs
213221
x_splits = split(
@@ -222,11 +230,14 @@ def check_for_block_diag(x):
222230
]
223231
new_output = join(1, *new_components)
224232

233+
# Case 2: Both inputs are BlockDiagonal. No rewrite is applied.
225234
else:
235+
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
236+
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
226237
return None
227238

228239
copy_stack_trace(node.outputs[0], new_output)
229-
return [new_output]
240+
return {dot_node.outputs[0]: new_output}
230241

231242

232243
@register_canonicalize

0 commit comments

Comments
 (0)