@@ -170,53 +170,54 @@ def check_for_block_diag(x):
170
170
)
171
171
172
172
# Check that the BlockDiagonal is an input to a Dot node:
173
- clients = list ( get_clients_at_depth (fgraph , node , depth = 1 ))
174
- if not clients or len ( clients ) > 1 or not isinstance (clients [ 0 ] .op , Dot ):
175
- return
173
+ for client in get_clients_at_depth (fgraph , node , depth = 1 ):
174
+ if not isinstance (client .op , Dot ):
175
+ return
176
176
177
- [dot_node ] = clients
178
- op = dot_node .op
179
- x , y = dot_node .inputs
177
+ op = client .op
178
+ x , y = client .inputs
180
179
181
- if not (check_for_block_diag (x ) or check_for_block_diag (y )):
182
- return None
180
+ if not (check_for_block_diag (x ) or check_for_block_diag (y )):
181
+ return None
183
182
184
- # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
185
- # non-block diagonal, and return a new block diagonal
186
- if check_for_block_diag (x ) and not check_for_block_diag (y ):
187
- components = x .owner .inputs
188
- y_splits = split (
189
- y ,
190
- splits_size = [component .shape [- 1 ] for component in components ],
191
- n_splits = len (components ),
192
- )
193
- new_components = [
194
- op (component , y_split ) for component , y_split in zip (components , y_splits )
195
- ]
196
- new_output = join (0 , * new_components )
197
-
198
- elif not check_for_block_diag (x ) and check_for_block_diag (y ):
199
- components = y .owner .inputs
200
- x_splits = split (
201
- x ,
202
- splits_size = [component .shape [0 ] for component in components ],
203
- n_splits = len (components ),
204
- axis = 1 ,
205
- )
183
+ # Case 1: Only one input is BlockDiagonal. In this case, multiply each component of the block-diagonal with the
184
+ # matching split of the non-block-diagonal input, and join the per-component products into the new output
185
+ if check_for_block_diag (x ) and not check_for_block_diag (y ):
186
+ components = x .owner .inputs
187
+ y_splits = split (
188
+ y ,
189
+ splits_size = [component .shape [- 1 ] for component in components ],
190
+ n_splits = len (components ),
191
+ )
192
+ new_components = [
193
+ op (component , y_split )
194
+ for component , y_split in zip (components , y_splits )
195
+ ]
196
+ new_output = join (0 , * new_components )
197
+
198
+ elif not check_for_block_diag (x ) and check_for_block_diag (y ):
199
+ components = y .owner .inputs
200
+ x_splits = split (
201
+ x ,
202
+ splits_size = [component .shape [0 ] for component in components ],
203
+ n_splits = len (components ),
204
+ axis = 1 ,
205
+ )
206
206
207
- new_components = [
208
- op (x_split , component ) for component , x_split in zip (components , x_splits )
209
- ]
210
- new_output = join (1 , * new_components )
207
+ new_components = [
208
+ op (x_split , component )
209
+ for component , x_split in zip (components , x_splits )
210
+ ]
211
+ new_output = join (1 , * new_components )
211
212
212
- # Case 2: Both inputs are BlockDiagonal. Do nothing
213
- else :
214
- # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
215
- # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
216
- return None
213
+ # Case 2: Both inputs are BlockDiagonal. Do nothing
214
+ else :
215
+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
216
+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
217
+ return None
217
218
218
- copy_stack_trace (node .outputs [0 ], new_output )
219
- return {dot_node .outputs [0 ]: new_output }
219
+ copy_stack_trace (node .outputs [0 ], new_output )
220
+ return {client .outputs [0 ]: new_output }
220
221
221
222
222
223
@register_canonicalize
0 commit comments