Commit 2d35d6c

Lift Subtensor over Join
1 parent c20a47b commit 2d35d6c
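
The rewrite added here lifts a Subtensor (basic indexing) through a Join (concatenate), so the indexing is applied to the joined inputs instead of the joined output. A minimal sketch with the public PyTensor API (the variable names and shapes are illustrative, not from the commit, and `rewrite_graph` is assumed to be importable from `pytensor.graph.rewriting.utils`, as used by the test module below):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))

    # Indexing the output of a Join (concatenate) along a non-joined axis...
    out = pt.concatenate([x, y], axis=1)[0]

    # ...should, after this commit, be rewritten into a graph equivalent to
    # pt.concatenate([x[0], y[0]], axis=0), i.e. the Subtensor is lifted to the inputs.
    opt_out = rewrite_graph(out)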

File tree

2 files changed: +150 -26 lines changed


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 95 additions & 25 deletions
@@ -1,19 +1,22 @@
 from collections.abc import Iterable, Sequence
+from typing import cast

 import numpy as np

 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -46,6 +49,7 @@
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable


 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -68,6 +72,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
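
A rough sketch (plain Python, not part of the commit) of the index bookkeeping the new `_lift_subtensor_non_axis` helper performs: the index on the protected `axis` is swapped for a full slice so the remaining indices can be lifted, and it is then re-applied on the result, shifted left by however many dimensions the lifted integer indices dropped. `_ndim_dropped_left_of_axis_by_basic_index` is approximated below by counting integer indices.

    # Example: variable[0, 1:] where axis == 1 must stay on the outer Subtensor
    idx_tuple = (0, slice(1, None))
    axis = 1

    idx_to_keep = idx_tuple[axis]  # slice(1, None)
    idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])  # (0, slice(None))

    # The integer index 0 drops one dimension to the left of `axis`,
    # so the kept index moves from axis 1 to axis 0.
    ndim_reduced_left = sum(isinstance(idx, int) for idx in idx_tuple[:axis])  # -> 1
    new_axis = axis - ndim_reduced_left  # -> 0
    idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)  # (slice(1, None),)

    # Net effect: variable[0, 1:] == variable[0, :][1:]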
@@ -299,29 +338,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

     # Index input to softmax
     x_sub = x[idx_tuple]
@@ -693,6 +717,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
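
The two graph identities in the docstring of `local_subtensor_of_join` can be sanity-checked with NumPy; this is only an illustration of the identities being exploited, not code from the commit:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 5, 3))
    y = rng.normal(size=(4, 2, 3))

    # join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    np.testing.assert_allclose(
        np.concatenate([x, y], axis=1)[0],
        np.concatenate([x[0], y[0]], axis=0),
    )

    # join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
    np.testing.assert_allclose(
        np.concatenate([x, y], axis=1)[:, 0, -1],
        np.concatenate([x[:, :, -1], y[:, :, -1]], axis=1)[:, 0],
    )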

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 55 additions & 1 deletion
@@ -37,7 +37,7 @@
     tensor3,
     vector,
 )
-from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
+from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
@@ -251,6 +251,9 @@ def test_local_subtensor_of_softmax(original_fn, expected_fn):
     )


+shared_axis = shared(1, "axis")
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
@@ -660,6 +663,57 @@ def test_empty_subtensor(self):
         assert local_subtensor_make_vector.transform(fgraph, node) == [v]


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        (
+            lambda x, y: concatenate([x, y], axis=1)[1],
+            lambda x, y: concatenate([x[1], y[1]], axis=0),
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=-1)[1:],
+            lambda x, y: concatenate([x[1:], y[1:]], axis=1),
+        ),
+        # Indexing on both axis of concatenation and somewhere else:
+        (
+            lambda x, y: concatenate([x, y], axis=1)[0, 1:],
+            lambda x, y: concatenate([x[0], y[0]], axis=0)[1:],
+        ),
+        # Not supported, indexing on axis of concatenation
+        (
+            lambda x, y: concatenate([x, y], axis=0)[0],
+            lambda x, y: concatenate([x, y], axis=0)[0],
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+        ),
+        # Not supported, axis of concatenation is dynamically determined
+        (
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+        ),
+    ],
+)
+def test_local_subtensor_of_join(original_fn, expected_fn):
+    rng = np.random.default_rng(257)
+    x = pt.matrix("x", shape=(5, 3))
+    y = pt.matrix("y", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+    y_test = rng.normal(size=y.type.shape)
+
+    out = original_fn(x, y)
+    expected_opt_out = expected_fn(x, y)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_shape_constant():
     x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
