 from collections.abc import Iterable, Sequence
+from typing import cast

 import numpy as np
 from numpy.core.numeric import (  # type: ignore
 )

 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable


 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice, ...],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
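+    """Lift non-axis indices of a Subtensor above `variable`, keeping the axis index.
+
+    For example, with ``axis=1``::
+
+        variable[0, 1:, 2] -> variable[0, :, 2][1:]
+
+    The integer index ``0`` drops the leading dimension, so the kept index
+    ``1:`` is re-applied at the shifted position (``new_axis = 0``).
+    """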
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
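+        # Integer indices to the left of `axis` drop dimensions, so the kept
+        # axis index must be re-applied at a shifted position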
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

     # Index input to softmax
     x_sub = x[idx_tuple]
@@ -696,6 +720,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
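+
+    In the second example the join axis itself is indexed, so the non-axis
+    indices are lifted first and the axis index is re-applied afterwards.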
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
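+    # `normalize_axis_index` maps a (possibly negative) constant axis into
+    # the range [0, ndim), raising if it is out of bounds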
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
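+    # Dimensions dropped by integer indices left of the join axis shift it:
+    # e.g. join(axis=1, x, y)[0] joins x[0] and y[0] along axis 0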
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])