from collections.abc import Iterable, Sequence
+from typing import cast

import numpy as np
from numpy.core.numeric import (  # type: ignore
    normalize_axis_index,
    normalize_axis_tuple,
)

from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
    Alloc,
+    Join,
    MakeVector,
    alloc,
    as_tensor,
    expand_dims,
    get_underlying_scalar_constant_value,
+    join,
    register_infer_shape,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable


def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index(
    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice, ...],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
@register_canonicalize
@register_stabilize
@register_specialize
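
The index-splitting step in _lift_subtensor_non_axis rests on a plain indexing identity: apply every index except the one at `axis` first, leaving a full slice in its place, then reapply the kept index shifted left past the integer indices that dropped dimensions before it. A minimal NumPy sketch of that identity follows; the array and index tuple are made up for illustration, and the dropped-dims count mirrors what _ndim_dropped_left_of_axis_by_basic_index computes for basic indices.

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
idx_tuple = (0, slice(1, 3), 1)  # basic indices on all three dims
axis = 1  # the index on this axis is kept for the second step

# Step 1: lift the non-axis indices, leaving a full slice at `axis`
idx_to_keep = idx_tuple[axis]
idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

# Step 2: reapply the kept index; each integer index before `axis`
# drops a dimension, so the axis shifts left by their count
ndim_dropped_left = sum(isinstance(i, int) for i in idx_tuple[:axis])
idxs_to_keep = (*(slice(None),) * (axis - ndim_dropped_left), idx_to_keep)

assert np.array_equal(x[idx_tuple], x[idxs_to_lift][idxs_to_keep])
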
@@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node):
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # If there are more dimensions being indexed, we can split them
        # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

    # Index input to softmax
    x_sub = x[idx_tuple]
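
The graph identity implemented here, that indexing along a non-softmax axis commutes with the softmax, can be checked numerically. A small sketch with an inline NumPy softmax (not pytensor's implementation): the integer index on axis 0 drops that dimension, so the softmax axis renumbers from 1 to 0.

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.default_rng(0).normal(size=(3, 4))

# softmax(x, axis=1)[0] is the softmax of row 0, i.e. softmax(x[0], axis=0)
np.testing.assert_allclose(softmax(x, axis=1)[0], softmax(x[0], axis=0))
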
@@ -695,6 +719,52 @@ def local_subtensor_make_vector(fgraph, node):
        pass


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
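
Both docstring cases of local_subtensor_of_join can be sanity-checked with NumPy, using np.concatenate as a stand-in for join; the arrays below are arbitrary examples.

import numpy as np

x = np.arange(6).reshape(2, 3)
y = np.arange(8).reshape(2, 4)

# Case 1: the integer index is on a non-join axis, so it distributes over
# the components and the join axis shifts left past the dropped dimension:
# join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
np.testing.assert_array_equal(
    np.concatenate([x, y], axis=1)[0],
    np.concatenate([x[0], y[0]], axis=0),
)

# Case 2: the join axis (axis 1) is itself indexed, so only the other
# indices are lifted (the _lift_subtensor_non_axis path) and the axis
# index is reapplied afterwards:
# join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
x3 = np.arange(24).reshape(2, 3, 4)
y3 = np.arange(16).reshape(2, 2, 4)
np.testing.assert_array_equal(
    np.concatenate([x3, y3], axis=1)[:, 0, -1],
    np.concatenate([x3[:, :, -1], y3[:, :, -1]], axis=1)[:, 0],
)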