Skip to content

Commit 14ee8e2

Browse files
Respond to feedback
1 parent b619e87 commit 14ee8e2

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,8 @@ def local_0_dot_x(fgraph, node):
149149
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
150150

151151

152-
@register_canonicalize
153-
@register_specialize
154152
@register_stabilize
155-
@node_rewriter([Dot])
153+
@node_rewriter([Blockwise])
156154
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
157155
r"""
158156
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C_1), dot(B, C_2))``, where ``C_1`` and ``C_2`` are the row blocks of ``C`` that line up with ``A`` and ``B``.
@@ -161,8 +159,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
161159
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
162160
a single dot on the larger matrix.
163161
"""
164-
x, y = node.inputs
165-
op = node.op
162+
if not isinstance(node.op.core_op, BlockDiagonal):
163+
return
166164

167165
def check_for_block_diag(x):
168166
return x.owner and (
@@ -171,6 +169,15 @@ def check_for_block_diag(x):
171169
and isinstance(x.owner.op.core_op, BlockDiagonal)
172170
)
173171

172+
# Check that the BlockDiagonal is an input to a Dot node:
173+
clients = list(get_clients_at_depth(fgraph, node, depth=1))
174+
if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
175+
return
176+
177+
[dot_node] = clients
178+
op = dot_node.op
179+
x, y = dot_node.inputs
180+
174181
if not (check_for_block_diag(x) or check_for_block_diag(y)):
175182
return None
176183

@@ -187,6 +194,7 @@ def check_for_block_diag(x):
187194
op(component, y_split) for component, y_split in zip(components, y_splits)
188195
]
189196
new_output = join(0, *new_components)
197+
190198
elif not check_for_block_diag(x) and check_for_block_diag(y):
191199
components = y.owner.inputs
192200
x_splits = split(
@@ -201,11 +209,14 @@ def check_for_block_diag(x):
201209
]
202210
new_output = join(1, *new_components)
203211

212+
# Case 2: Both inputs are BlockDiagonal. Do nothing
204213
else:
214+
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
215+
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
205216
return None
206217

207218
copy_stack_trace(node.outputs[0], new_output)
208-
return [new_output]
219+
return {dot_node.outputs[0]: new_output}
209220

210221

211222
@register_canonicalize

0 commit comments

Comments
 (0)