@@ -1,19 +1,22 @@
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 import numpy as np
 
 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -46,6 +49,7 @@
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable
 
 
 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -68,6 +72,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
 
 
+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice, ...],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply a generic subtensor lift rewrite along the "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
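
To make the index bookkeeping in _lift_subtensor_non_axis concrete, here is a
small plain-Python sketch (hypothetical values, not part of the commit) of the
split it performs for idx_tuple=(0, 1, -1) with axis=1:

    # The index on the protected axis is kept aside...
    idx_tuple = (0, 1, -1)
    axis = 1
    idx_to_keep = idx_tuple[axis]  # 1
    # ...while the remaining indexes are lifted, with a full slice standing
    # in for the protected axis.
    idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
    assert idxs_to_lift == (0, slice(None), -1)
    # One dimension (the leading 0) is dropped to the left of the axis,
    # so the axis index is reapplied at the shifted position.
    ndim_reduced_left = 1
    new_axis = axis - ndim_reduced_left
    idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
    assert idxs_to_keep == (1,)  # v[0, :, -1][1] is v[0, 1, -1]
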
@@ -299,29 +338,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
 
     # Index input to softmax
     x_sub = x[idx_tuple]
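
For reference, a minimal sketch of how the refactored softmax rewrite can be
exercised end to end (assuming a standard PyTensor install; the rewritten
graph shown in the final comment is approximate):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph
    from pytensor.tensor.special import softmax

    x = pt.tensor3("x")
    # Indexes both the softmax axis (1) and a non-axis dimension (0), so the
    # rewrite goes through the _lift_subtensor_non_axis split path.
    out = softmax(x, axis=1)[0, -1]
    rewritten = rewrite_graph(out, include=("canonicalize",))
    # Expected, roughly: softmax(x[0], axis=0)[-1]
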
@@ -693,6 +717,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # The rewrite only works when the join axis is a constant
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift the index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
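
Similarly, a minimal sketch exercising the new Join rewrite (same assumptions
as above):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.matrix("x")
    y = pt.matrix("y")
    # Index a non-join dimension; the Subtensor should be pushed into the
    # Join components and the join axis renumbered from 1 to 0.
    out = pt.join(1, x, y)[0]
    rewritten = rewrite_graph(out, include=("canonicalize",))
    # Expected, roughly: join(0, x[0], y[0])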