    register_stabilize,
)
from pytensor.tensor.shape import Reshape
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)


@node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):

    Reshape is tricky to vectorize eagerly, because a graph like
    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
-    that must be vectorized before we arrize at the reshape operation.
+    that must be vectorized before we arrive at the reshape operation.

-    For the square Reshape case, we must wait for all the intemediate
+    For the square Reshape case, we must wait for all the intermediate
    operations to be lifted as Allocs
    """
    if not isinstance(node.op.core_op, Reshape):
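
(To illustrate the point in the docstring above: a minimal sketch of the eager-vectorization problem, assuming the public `vectorize_graph` helper from `pytensor.graph.replace`; `batch_x` and its shape are made up for the example.)

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.matrix("x")
    # The target shape is itself a symbolic graph (Shape, Mul, ...), so
    # vectorization must process all of it before reaching the Reshape node.
    out = x.reshape([x.shape[0] * x.shape[1], -1])

    batch_x = pt.tensor("batch_x", shape=(2, None, None))  # hypothetical batch input
    batch_out = vectorize_graph(out, replace={x: batch_x})
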
@@ -234,6 +239,26 @@ def local_blockwise_reshape(fgraph, node):
        return [new_out]


+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where the only batched input is the tensor being indexed."""
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    blockwise_op = node.op
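
(For reference, a minimal sketch of the pattern the new rewrite targets, under the same `vectorize_graph` assumption as above; all names are illustrative.)

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    i = pt.scalar("i", dtype="int64")
    x = pt.matrix("x")
    out = x[i:]  # core Subtensor whose index has no batch dimensions

    batch_x = pt.tensor("batch_x", shape=(2, None, None))
    # Vectorizing may leave a Blockwise(Subtensor) node here; because the
    # index is not batched, the rewrite reduces it to plain indexing with
    # a leading full slice, i.e. batch_x[:, i:].
    batch_out = vectorize_graph(out, replace={x: batch_x})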