@@ -170,10 +170,8 @@ def local_0_dot_x(fgraph, node):
170
170
return [constant_zero ]
171
171
172
172
173
- @register_canonicalize
174
- @register_specialize
175
173
@register_stabilize
176
- @node_rewriter ([Dot ])
174
+ @node_rewriter ([Blockwise ])
177
175
def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
178
176
r"""
179
177
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
@@ -182,8 +180,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
182
180
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
183
181
a single dot on the larger matrix.
184
182
"""
185
- x , y = node .inputs
186
- op = node . op
183
+ if not isinstance ( node .op . core_op , BlockDiagonal ):
184
+ return
187
185
188
186
def check_for_block_diag (x ):
189
187
return x .owner and (
@@ -192,6 +190,15 @@ def check_for_block_diag(x):
192
190
and isinstance (x .owner .op .core_op , BlockDiagonal )
193
191
)
194
192
193
+ # Check that the BlockDiagonal is an input to a Dot node:
194
+ clients = list (get_clients_at_depth (fgraph , node , depth = 1 ))
195
+ if not clients or len (clients ) > 1 or not isinstance (clients [0 ].op , Dot ):
196
+ return
197
+
198
+ [dot_node ] = clients
199
+ op = dot_node .op
200
+ x , y = dot_node .inputs
201
+
195
202
if not (check_for_block_diag (x ) or check_for_block_diag (y )):
196
203
return None
197
204
@@ -208,6 +215,7 @@ def check_for_block_diag(x):
208
215
op (component , y_split ) for component , y_split in zip (components , y_splits )
209
216
]
210
217
new_output = join (0 , * new_components )
218
+
211
219
elif not check_for_block_diag (x ) and check_for_block_diag (y ):
212
220
components = y .owner .inputs
213
221
x_splits = split (
@@ -222,11 +230,14 @@ def check_for_block_diag(x):
222
230
]
223
231
new_output = join (1 , * new_components )
224
232
233
+ # Case 2: Both inputs are BlockDiagonal. Do nothing
225
234
else :
235
+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
236
+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
226
237
return None
227
238
228
239
copy_stack_trace (node .outputs [0 ], new_output )
229
- return [ new_output ]
240
+ return { dot_node . outputs [ 0 ]: new_output }
230
241
231
242
232
243
@register_canonicalize
0 commit comments