2929 constant ,
3030 expand_dims ,
3131 get_underlying_scalar_constant_value ,
32- join ,
3332 moveaxis ,
3433 ones_like ,
3534 register_infer_shape ,
4140from pytensor .tensor .blockwise import Blockwise
4241from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
4342from pytensor .tensor .exceptions import NotScalarConstantError
44- from pytensor .tensor .extra_ops import broadcast_arrays
43+ from pytensor .tensor .extra_ops import broadcast_arrays , concat_with_broadcast
4544from pytensor .tensor .math import (
4645 Dot ,
4746 Prod ,
@@ -151,6 +150,7 @@ def local_0_dot_x(fgraph, node):
151150
152151
153152@register_stabilize
153+ @register_specialize
154154@node_rewriter ([Blockwise ])
155155def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
156156 r"""
@@ -174,29 +174,31 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
174174 ):
175175 continue
176176
177- op = client .op
177+ [blockdiag_result ] = node .outputs
178+ blockdiag_inputs = node .inputs
178179
179- client_idx = client .inputs . index ( node . outputs [ 0 ])
180+ dot_op = client .op
180181
182+ client_idx = client .inputs .index (blockdiag_result )
181183 other_input = client .inputs [1 - client_idx ]
182- components = node .inputs
183184
184185 split_axis = - 2 if client_idx == 0 else - 1
185186 shape_idx = - 1 if client_idx == 0 else - 2
186187
187188 other_dot_input_split = split (
188189 other_input ,
189- splits_size = [component .shape [shape_idx ] for component in components ],
190- n_splits = len (components ),
190+ splits_size = [component .shape [shape_idx ] for component in blockdiag_inputs ],
191+ n_splits = len (blockdiag_inputs ),
191192 axis = split_axis ,
192193 )
193- new_components = [
194- op (component , other_split )
194+
195+ split_dot_results = [
196+ dot_op (component , other_split )
195197 if client_idx == 0
196- else op (other_split , component )
197- for component , other_split in zip (components , other_dot_input_split )
198+ else dot_op (other_split , component )
199+ for component , other_split in zip (blockdiag_inputs , other_dot_input_split )
198200 ]
199- new_output = join ( split_axis , * new_components )
201+ new_output = concat_with_broadcast ( split_dot_results , dim = split_axis )
200202
201203 copy_stack_trace (node .outputs [0 ], new_output )
202204 return {client .outputs [0 ]: new_output }
0 commit comments