@@ -155,7 +155,7 @@ def local_0_dot_x(fgraph, node):
155
155
@node_rewriter ([Dot ])
156
156
def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
157
157
r"""
158
- Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag (dot(A, C), dot(B, C))``
158
+ Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C_1), dot(B, C_2))``, where ``C_1`` and ``C_2`` are the row blocks of ``C`` matching the columns of ``A`` and ``B``
159
159
160
160
BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
161
161
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
@@ -189,25 +189,18 @@ def check_for_block_diag(x):
189
189
new_output = join (0 , * new_components )
190
190
elif not check_for_block_diag (x ) and check_for_block_diag (y ):
191
191
components = y .owner .inputs
192
- new_components = [op (x , component ) for component in components ]
193
- new_output = join (0 , * new_components )
194
-
195
- # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
196
- # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
197
- elif any (shape is None for shape in (* x .type .shape , * y .type .shape )):
198
- return None
199
- elif x .ndim == y .ndim and all (
200
- x_shape == y_shape for x_shape , y_shape in zip (x .type .shape , y .type .shape )
201
- ):
202
- x_components = x .owner .inputs
203
- y_components = y .owner .inputs
192
+ x_splits = split (
193
+ x ,
194
+ splits_size = [component .shape [0 ] for component in components ],
195
+ n_splits = len (components ),
196
+ axis = 1 ,
197
+ )
204
198
205
- if len (x_components ) != len (y_components ):
206
- return None
199
+ new_components = [
200
+ op (x_split , component ) for component , x_split in zip (components , x_splits )
201
+ ]
202
+ new_output = join (1 , * new_components )
207
203
208
- new_output = BlockDiagonal (len (x_components ))(
209
- * [op (x_comp , y_comp ) for x_comp , y_comp in zip (x_components , y_components )]
210
- )
211
204
else :
212
205
return None
213
206
0 commit comments