Commit 2c28177

Lift Subtensor over Join
1 parent f24a974 commit 2c28177
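
In graph terms, the new rewrite pushes a basic index through a Join (concatenation) node: indices on non-join axes are applied to each joined component, while an index on the join axis itself blocks the lift. A minimal sketch of the intended effect, not part of the commit; it assumes rewrite_graph (as used by the new test below) applies the default canonicalize/specialize rewrite databases:

    import pytensor.tensor as pt
    from pytensor.graph import rewrite_graph

    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))

    # Subtensor of a Join along a non-join axis ...
    out = pt.concatenate([x, y], axis=1)[1]
    # ... is expected to be lifted onto the components, i.e. to become
    # equivalent to pt.concatenate([x[1], y[1]], axis=0)
    opt_out = rewrite_graph(out)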

2 files changed: +150 additions, -26 deletions

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 95 additions & 25 deletions
@@ -1,4 +1,5 @@
 from collections.abc import Iterable, Sequence
+from typing import cast

 import numpy as np
 from numpy.core.numeric import (  # type: ignore
@@ -7,16 +8,18 @@
 )

 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -49,6 +52,7 @@
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable


 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

     # Index input to softmax
     x_sub = x[idx_tuple]
@@ -696,6 +720,52 @@ def local_subtensor_make_vector(fgraph, node):
             pass


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
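
As a reading aid for the helper above, a worked trace (not in the diff) of how _lift_subtensor_non_axis splits an index tuple when the axis of interest is indexed together with other dimensions, mirroring the second docstring example of local_subtensor_of_join:

    # variable = join(axis=1, x, y) with ndim == 3; idx_tuple = (slice(None), 0, -1); axis = 1
    # idx_to_keep  = 0                               # index on the join axis stays behind
    # idxs_to_lift = (slice(None), slice(None), -1)  # lifted into the components: x[:, :, -1], y[:, :, -1]
    # ndim_reduced_left = 0                          # no integer index drops a dimension left of axis
    # idxs_to_keep = (slice(None), 0)                # reapplied on top of the new Join
    # overall: join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]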

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 55 additions & 1 deletion
@@ -42,7 +42,7 @@
     tensor3,
     vector,
 )
-from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
+from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
@@ -258,6 +258,9 @@ def test_local_subtensor_of_softmax(original_fn, expected_fn):
     )


+shared_axis = shared(1, "axis")
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
@@ -667,6 +670,57 @@ def test_empty_subtensor(self):
         assert local_subtensor_make_vector.transform(fgraph, node) == [v]


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        (
+            lambda x, y: concatenate([x, y], axis=1)[1],
+            lambda x, y: concatenate([x[1], y[1]], axis=0),
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=-1)[1:],
+            lambda x, y: concatenate([x[1:], y[1:]], axis=1),
+        ),
+        # Indexing on both axis of concatenation and somewhere else:
+        (
+            lambda x, y: concatenate([x, y], axis=1)[0, 1:],
+            lambda x, y: concatenate([x[0], y[0]], axis=0)[1:],
+        ),
+        # Not supported, indexing on axis of concatenation
+        (
+            lambda x, y: concatenate([x, y], axis=0)[0],
+            lambda x, y: concatenate([x, y], axis=0)[0],
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+        ),
+        # Not supported, axis of concatenation is dynamically determined
+        (
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+        ),
+    ],
+)
+def test_local_subtensor_of_join(original_fn, expected_fn):
+    rng = np.random.default_rng(257)
+    x = pt.matrix("x", shape=(5, 3))
+    y = pt.matrix("y", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+    y_test = rng.normal(size=y.type.shape)
+
+    out = original_fn(x, y)
+    expected_opt_out = expected_fn(x, y)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_shape_constant():
     x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
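
The new parametrized cases can be run in isolation with pytest, e.g. pytest tests/tensor/rewriting/test_subtensor_lift.py -k test_local_subtensor_of_join.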
