from pytensor import Variable
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
- from pytensor.npy_2_compat import normalize_axis_tuple
+ from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
    Alloc,
    specify_shape,
    unbroadcast,
)
+ from pytensor.tensor.special import Softmax, softmax
from pytensor.tensor.subtensor import (
    AdvancedSubtensor1,
    Subtensor,
@@ -53,6 +54,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
    return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+ def _ndim_dropped_left_of_axis_by_basic_index(
+     idxs: Sequence[slice | int], axis: int
+ ) -> int:
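+     """Number of dimensions dropped by the integer indices left of ``axis``."""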
+     return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+ def _axis_is_indexed_by_basic_index(
+     idxs: Sequence[slice | int], axis: int | Sequence[int]
+ ) -> bool:
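+     """Whether any of the given axes is actually indexed (i.e., not a full slice)."""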
+     if isinstance(axis, int):
+         axis = (axis,)
+     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
@register_canonicalize
@register_stabilize
@register_specialize
@@ -243,6 +258,84 @@ def local_subtensor_of_reduce(fgraph, node):
    return [out]


+ @register_canonicalize
+ @register_specialize
+ @node_rewriter([Subtensor])
+ def local_subtensor_of_softmax(fgraph, node):
265+ """Lift a Subtensor through a Softmax.
266+
267+ softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
268+ softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
269+
270+ If part of the indexing acts on the axis of reduction, we split it
271+ softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]
272+
273+ """
+     sm, *idx = node.inputs
+
+     if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+         return None
+
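+     # Only rewrite when this is the sole client of the softmax output;
+     # otherwise the original softmax survives and would be computed twice.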
+     if len(fgraph.clients[sm]) > 1:
+         return None
+
+     [x] = sm.owner.inputs
+     axis = sm.owner.op.axis
+
+     if axis is None:
+         if x.type.ndim == 1:
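+             # softmax over all elements of a vector is softmax along its only axis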
+             axis = 0
+         else:
+             # All dimensions are mixed, we can't lift the subtensor
+             return None
+     else:
+         # Softmax currently only allows None or a single integer axis.
+         # Unlike CAReduce, it does not normalize negative indices.
+         axis = normalize_axis_index(axis, sm.ndim)
+
+     [old_out] = node.outputs
+     idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+         # If more dimensions are being indexed, we can split the indices:
+         # lift the non-axis ones and keep the axis index in place
+         real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+         if len(real_indices) > 1 and sm.type.ndim > 1:
+             # Split the subtensor
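+             # e.g. for softmax(x, axis=1)[:, 0, 1:], keep the 0 on axis 1 and
+             # lift (slice(None), slice(None), slice(1, None)) through the softmax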
+             idx_to_keep = idx_tuple[axis]
+             idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+             # Lift the non-axis indexes by calling the rewrite itself
+             opt_sm = sm[idxs_to_lift]
+             [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+             copy_stack_trace([old_out, sm], opt_sm)
+
+             # Then reintroduce the axis index
+             ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                 idx_tuple, axis
+             )
+             new_axis = axis - ndim_reduced_left
+             idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
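+             # In the example above nothing is dropped left of axis 1, so
+             # new_axis is 1 and we index opt_sm[:, 0]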
+             new_out = opt_sm[idxs_to_keep]
+             copy_stack_trace(old_out, new_out)
+             return [new_out]
+
+         else:
+             return None
+
+     # Index input to softmax
+     x_sub = x[idx_tuple]
+
+     # Adjust the axis of reduction when indexing drops dimensions
+     # (integer indexing, as opposed to slice indexing)
+     axis -= len(
+         [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+     )
+
+     out = softmax(x_sub, axis=axis)
+     copy_stack_trace(old_out, out)
+     return [out]
+
+
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])