11from collections .abc import Iterable , Sequence
22
33import numpy as np
4- from numpy .core .numeric import normalize_axis_tuple # type: ignore
4+ from numpy .core .numeric import ( # type: ignore
5+ normalize_axis_index ,
6+ normalize_axis_tuple ,
7+ )
58
69from pytensor import Variable
710from pytensor .graph import Constant , node_rewriter
3437 specify_shape ,
3538 unbroadcast ,
3639)
40+ from pytensor .tensor .special import Softmax , softmax
3741from pytensor .tensor .subtensor import (
3842 AdvancedSubtensor1 ,
3943 Subtensor ,
@@ -53,6 +57,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
5357 return tuple (i for i , idx in enumerate (idxs ) if not isinstance (idx , slice ))
5458
5559
def _ndim_dropped_left_of_axis_by_basic_index(
    idxs: Sequence[slice | int], axis: int
) -> int:
    """Return how many dimensions basic indexing drops to the left of `axis`.

    Only the entries of `idxs` preceding `axis` are considered; every
    integer (non-slice) entry removes one dimension from the result.
    """
    dropped_before_axis = _dims_dropped_by_basic_index(idxs[:axis])
    return len(dropped_before_axis)
64+
65+
def _axis_is_indexed_by_basic_index(
    idxs: Sequence [slice | int], axis: int | Sequence[int]
) -> bool:
    """Return True if the basic index `idxs` acts on any of the given axes.

    An axis counts as indexed when an index entry exists at that position
    and it is not a full `slice(None)`.
    """
    axes = (axis,) if isinstance(axis, int) else axis
    n_indexed = len(idxs)
    for ax in axes:
        if ax < n_indexed and not is_full_slice(idxs[ax]):
            return True
    return False
72+
73+
5674@register_canonicalize
5775@register_stabilize
5876@register_specialize
@@ -243,6 +261,84 @@ def local_subtensor_of_reduce(fgraph, node):
243261 return [out ]
244262
245263
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_softmax(fgraph, node):
    """Lift a Subtensor through a Softmax.

    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)

    If part of the indexing acts on the axis of reduction, we split it
    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]

    """
    sm, *idx = node.inputs

    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
        return None

    # Only lift when the softmax output has no other consumers; otherwise
    # the rewrite would duplicate the softmax computation.
    if len(fgraph.clients[sm]) > 1:
        return None

    [x] = sm.owner.inputs
    axis = sm.owner.op.axis

    if axis is None:
        if x.type.ndim == 1:
            axis = 0
        else:
            # All dimensions are mixed, we can't lift the subtensor
            return None
    else:
        # Softmax currently only allows None or a single integer axis
        # Unlike CAReduce it does not normalize negative indices
        axis = normalize_axis_index(axis, sm.ndim)

    [old_out] = node.outputs
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # If there are more dimensions being indexed, we can split them
        # And lift the non-axis indexes while keeping the axis index
        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
        if len(real_indices) > 1 and sm.type.ndim > 1:
            # Split the subtensor
            idx_to_keep = idx_tuple[axis]
            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

            # Lift the non-axis indexes by calling the rewrite itself
            opt_sm = sm[idxs_to_lift]
            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
            copy_stack_trace([old_out, sm], opt_sm)

            # Then reintroduce the axis index
            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
                idx_tuple, axis
            )
            new_axis = axis - ndim_reduced_left
            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
            new_out = opt_sm[idxs_to_keep]
            copy_stack_trace(old_out, new_out)
            return [new_out]

        else:
            return None

    # Index input to softmax
    x_sub = x[idx_tuple]

    # Adjust the axis of reduction for dimensions dropped by integer
    # indexing (as opposed to slice indexing) to its left, using the same
    # helper as the split branch above for consistency.
    axis -= _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)

    out = softmax(x_sub, axis=axis)
    copy_stack_trace(old_out, out)
    return [out]
340+
341+
246342@register_canonicalize ("shape_unsafe" )
247343@register_specialize ("shape_unsafe" )
248344@node_rewriter ([Subtensor ])
0 commit comments