@@ -176,7 +176,7 @@ def local_0_dot_x(fgraph, node):
176
176
@node_rewriter ([Dot ])
177
177
def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
178
178
r"""
179
- Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag (dot(A, C), dot(B, C))``
179
+ Perform the rewrite ``dot(block_diag(A, B), C) -> concat (dot(A, C), dot(B, C))``
180
180
181
181
BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
182
182
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
@@ -210,25 +210,18 @@ def check_for_block_diag(x):
210
210
new_output = join (0 , * new_components )
211
211
elif not check_for_block_diag (x ) and check_for_block_diag (y ):
212
212
components = y .owner .inputs
213
- new_components = [op (x , component ) for component in components ]
214
- new_output = join (0 , * new_components )
215
-
216
- # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
217
- # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
218
- elif any (shape is None for shape in (* x .type .shape , * y .type .shape )):
219
- return None
220
- elif x .ndim == y .ndim and all (
221
- x_shape == y_shape for x_shape , y_shape in zip (x .type .shape , y .type .shape )
222
- ):
223
- x_components = x .owner .inputs
224
- y_components = y .owner .inputs
213
+ x_splits = split (
214
+ x ,
215
+ splits_size = [component .shape [0 ] for component in components ],
216
+ n_splits = len (components ),
217
+ axis = 1 ,
218
+ )
225
219
226
- if len (x_components ) != len (y_components ):
227
- return None
220
+ new_components = [
221
+ op (x_split , component ) for component , x_split in zip (components , x_splits )
222
+ ]
223
+ new_output = join (1 , * new_components )
228
224
229
- new_output = BlockDiagonal (len (x_components ))(
230
- * [op (x_comp , y_comp ) for x_comp , y_comp in zip (x_components , y_components )]
231
- )
232
225
else :
233
226
return None
234
227
0 commit comments