@@ -170,10 +170,8 @@ def local_0_dot_x(fgraph, node):
170170 return [constant_zero ]
171171
172172
173- @register_canonicalize
174- @register_specialize
175173@register_stabilize
176- @node_rewriter ([Dot ])
174+ @node_rewriter ([Blockwise ])
177175def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
178176 r"""
179177 Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C_1), dot(B, C_2))``, where ``C_1`` and ``C_2`` are the splits of ``C`` paired with ``A`` and ``B``.
@@ -182,8 +180,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
182180 of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
183181 a single dot on the larger matrix.
184182 """
185- x , y = node .inputs
186- op = node . op
183+ if not isinstance ( node .op . core_op , BlockDiagonal ):
184+ return
187185
188186 def check_for_block_diag (x ):
189187 return x .owner and (
@@ -192,6 +190,15 @@ def check_for_block_diag(x):
192190 and isinstance (x .owner .op .core_op , BlockDiagonal )
193191 )
194192
193+ # Check that the BlockDiagonal is an input to a Dot node:
194+ clients = list (get_clients_at_depth (fgraph , node , depth = 1 ))
195+ if not clients or len (clients ) > 1 or not isinstance (clients [0 ].op , Dot ):
196+ return
197+
198+ [dot_node ] = clients
199+ op = dot_node .op
200+ x , y = dot_node .inputs
201+
195202 if not (check_for_block_diag (x ) or check_for_block_diag (y )):
196203 return None
197204
@@ -208,6 +215,7 @@ def check_for_block_diag(x):
208215 op (component , y_split ) for component , y_split in zip (components , y_splits )
209216 ]
210217 new_output = join (0 , * new_components )
218+
211219 elif not check_for_block_diag (x ) and check_for_block_diag (y ):
212220 components = y .owner .inputs
213221 x_splits = split (
@@ -222,11 +230,14 @@ def check_for_block_diag(x):
222230 ]
223231 new_output = join (1 , * new_components )
224232
233+ # Case 2: Both inputs are BlockDiagonal. No rewrite is applied.
225234 else :
235+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
236+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
226237 return None
227238
228239 copy_stack_trace (node .outputs [0 ], new_output )
229- return [ new_output ]
240+ return { dot_node . outputs [ 0 ]: new_output }
230241
231242
232243@register_canonicalize
0 commit comments