@@ -149,10 +149,8 @@ def local_0_dot_x(fgraph, node):
149
149
return [zeros ((x .shape [0 ], y .shape [1 ]), dtype = node .outputs [0 ].type .dtype )]
150
150
151
151
152
- @register_canonicalize
153
- @register_specialize
154
152
@register_stabilize
155
- @node_rewriter ([Dot ])
153
+ @node_rewriter ([Blockwise ])
156
154
def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
157
155
r"""
158
156
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
@@ -161,8 +159,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
161
159
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
162
160
a single dot on the larger matrix.
163
161
"""
164
- x , y = node .inputs
165
- op = node . op
162
+ if not isinstance ( node .op . core_op , BlockDiagonal ):
163
+ return
166
164
167
165
def check_for_block_diag (x ):
168
166
return x .owner and (
@@ -171,6 +169,15 @@ def check_for_block_diag(x):
171
169
and isinstance (x .owner .op .core_op , BlockDiagonal )
172
170
)
173
171
172
+ # Check that the BlockDiagonal is an input to a Dot node:
173
+ clients = list (get_clients_at_depth (fgraph , node , depth = 1 ))
174
+ if not clients or len (clients ) > 1 or not isinstance (clients [0 ].op , Dot ):
175
+ return
176
+
177
+ [dot_node ] = clients
178
+ op = dot_node .op
179
+ x , y = dot_node .inputs
180
+
174
181
if not (check_for_block_diag (x ) or check_for_block_diag (y )):
175
182
return None
176
183
@@ -187,6 +194,7 @@ def check_for_block_diag(x):
187
194
op (component , y_split ) for component , y_split in zip (components , y_splits )
188
195
]
189
196
new_output = join (0 , * new_components )
197
+
190
198
elif not check_for_block_diag (x ) and check_for_block_diag (y ):
191
199
components = y .owner .inputs
192
200
x_splits = split (
@@ -201,11 +209,14 @@ def check_for_block_diag(x):
201
209
]
202
210
new_output = join (1 , * new_components )
203
211
212
+ # Case 2: Both inputs are BlockDiagonal. Do nothing
204
213
else :
214
+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
215
+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
205
216
return None
206
217
207
218
copy_stack_trace (node .outputs [0 ], new_output )
208
- return [ new_output ]
219
+ return { dot_node . outputs [ 0 ]: new_output }
209
220
210
221
211
222
@register_canonicalize
0 commit comments