@@ -1,19 +1,22 @@
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 import numpy as np
 
 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -44,6 +47,7 @@
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable
 
 
 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
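The body of _dims_dropped_by_basic_index is collapsed in this view. Going by its name and signature, integer indices drop their dimension while slices keep theirs; a sketch of that assumed behavior:

    # Assumed semantics (function body not shown in this diff):
    # integer indices drop their dimension, slices preserve it.
    # _dims_dropped_by_basic_index((0, slice(None), -1)) -> (0, 2)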
@@ -66,6 +70,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
 
 
+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
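A minimal walk-through of the index bookkeeping the new helper performs, with values assumed purely for illustration (a 2d variable whose axis 1 must be preserved):

    # Hypothetical values, mirroring _lift_subtensor_non_axis step by step
    idx_tuple = (0, 1)  # both axes indexed; axis 1 is the protected axis
    axis = 1
    idx_to_keep = idx_tuple[axis]                                   # 1
    idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
    assert idxs_to_lift == (0, slice(None))  # only non-axis indices remain
    # Lifting variable[0, :] through the op drops one dimension to the
    # left of axis 1, so the kept index is reapplied at position 0:
    new_axis = axis - 1  # via _ndim_dropped_left_of_axis_by_basic_index
    idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
    assert idxs_to_keep == (1,)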
@@ -297,29 +336,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
 
     # Index input to softmax
     x_sub = x[idx_tuple]
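The branch refactored above is what the helper now centralizes; the fall-through below it performs the actual lift. A sketch of the end-to-end effect, assuming the default rewrite pipeline is enabled:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.special import softmax

    x = pt.matrix("x")
    out = softmax(x, axis=1)[0]  # index on a non-softmax axis
    fn = pytensor.function([x], out)
    # The compiled graph should compute softmax(x[0], axis=0) rather than
    # materializing the full softmax before slicing row 0.
    pytensor.dprint(fn)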
@@ -646,6 +670,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
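A similar sketch for the new join rewrite, again with assumed variable names and relying on the default rewrite pipeline:

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")
    out = pt.join(1, x, y)[0]  # index along a non-join axis
    fn = pytensor.function([x, y], out)
    # Per the docstring above, the rewritten graph should be equivalent to
    # join(0, x[0], y[0]), joining the indexed components directly.
    pytensor.dprint(fn)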