@@ -191,53 +191,54 @@ def check_for_block_diag(x):
191191 )
192192
193193 # Check that the BlockDiagonal is an input to a Dot node:
194- clients = list ( get_clients_at_depth (fgraph , node , depth = 1 ))
195- if not clients or len ( clients ) > 1 or not isinstance (clients [ 0 ] .op , Dot ):
196- return
194+ for client in get_clients_at_depth (fgraph , node , depth = 1 ):
195+ if not isinstance (client .op , Dot ):
196+ return
197197
198- [dot_node ] = clients
199- op = dot_node .op
200- x , y = dot_node .inputs
198+ op = client .op
199+ x , y = client .inputs
201200
202- if not (check_for_block_diag (x ) or check_for_block_diag (y )):
203- return None
201+ if not (check_for_block_diag (x ) or check_for_block_diag (y )):
202+ return None
204203
205- # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
206- # non-block diagonal, and return a new block diagonal
207- if check_for_block_diag (x ) and not check_for_block_diag (y ):
208- components = x .owner .inputs
209- y_splits = split (
210- y ,
211- splits_size = [component .shape [- 1 ] for component in components ],
212- n_splits = len (components ),
213- )
214- new_components = [
215- op (component , y_split ) for component , y_split in zip (components , y_splits )
216- ]
217- new_output = join (0 , * new_components )
218-
219- elif not check_for_block_diag (x ) and check_for_block_diag (y ):
220- components = y .owner .inputs
221- x_splits = split (
222- x ,
223- splits_size = [component .shape [0 ] for component in components ],
224- n_splits = len (components ),
225- axis = 1 ,
226- )
204+ # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
205+ # non-block diagonal, and return a new block diagonal
206+ if check_for_block_diag (x ) and not check_for_block_diag (y ):
207+ components = x .owner .inputs
208+ y_splits = split (
209+ y ,
210+ splits_size = [component .shape [- 1 ] for component in components ],
211+ n_splits = len (components ),
212+ )
213+ new_components = [
214+ op (component , y_split )
215+ for component , y_split in zip (components , y_splits )
216+ ]
217+ new_output = join (0 , * new_components )
218+
219+ elif not check_for_block_diag (x ) and check_for_block_diag (y ):
220+ components = y .owner .inputs
221+ x_splits = split (
222+ x ,
223+ splits_size = [component .shape [0 ] for component in components ],
224+ n_splits = len (components ),
225+ axis = 1 ,
226+ )
227227
228- new_components = [
229- op (x_split , component ) for component , x_split in zip (components , x_splits )
230- ]
231- new_output = join (1 , * new_components )
228+ new_components = [
229+ op (x_split , component )
230+ for component , x_split in zip (components , x_splits )
231+ ]
232+ new_output = join (1 , * new_components )
232233
233- # Case 2: Both inputs are BlockDiagonal. Do nothing
234- else :
235- # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
236- # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
237- return None
234+ # Case 2: Both inputs are BlockDiagonal. Do nothing
235+ else :
236+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
237+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
238+ return None
238239
239- copy_stack_trace (node .outputs [0 ], new_output )
240- return {dot_node .outputs [0 ]: new_output }
240+ copy_stack_trace (node .outputs [0 ], new_output )
241+ return {client .outputs [0 ]: new_output }
241242
242243
243244@register_canonicalize
0 commit comments