
Commit b89487d

Lift Subtensor over Join
1 parent eb11f0f commit b89487d
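
The new rewrite pushes a basic Subtensor through a Join whenever the indexed dimensions do not include the join axis, e.g. join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0]). The identity it relies on is ordinary indexing semantics; a minimal NumPy sketch of it (array shapes chosen for illustration, not taken from the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(5, 3))
    y = rng.normal(size=(5, 3))

    # Indexing a dimension other than the join axis commutes with the join.
    # The join axis is renumbered (1 -> 0) because the indexed leading
    # dimension is dropped from the result.
    np.testing.assert_allclose(
        np.concatenate([x, y], axis=1)[0],
        np.concatenate([x[0], y[0]], axis=0),
    )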

File tree

pytensor/tensor/rewriting/subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py

2 files changed: +153 / -27 lines


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 98 additions & 26 deletions
@@ -4,8 +4,8 @@
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.scalar import basic as ps
 from pytensor.tensor import (
     TensorType,
@@ -17,7 +17,14 @@
     specify_shape,
     squeeze,
 )
-from pytensor.tensor.basic import Alloc, MakeVector, expand_dims, register_infer_shape
+from pytensor.tensor.basic import (
+    Alloc,
+    Join,
+    MakeVector,
+    expand_dims,
+    join,
+    register_infer_shape,
+)
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot
@@ -60,6 +67,40 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: Variable,
+    idx_tuple: tuple[int | slice],
+    axis: int,
+    old_subtensor_variable: Variable,
+) -> None | list[Variable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = local_subtensor_lift_rewrite.transform(
+            fgraph, indexed_variable.owner
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -291,29 +332,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

     # Index input to softmax
     x_sub = x[idx_tuple]
@@ -684,6 +710,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
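
When the join axis itself is also indexed, _lift_subtensor_non_axis first lifts only the non-axis indices into the components and then re-applies the axis index on the result, as in the second docstring example join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]. A minimal NumPy sketch of that two-step split (shapes chosen for illustration only):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 3, 4))
    y = rng.normal(size=(2, 5, 4))

    joined = np.concatenate([x, y], axis=1)

    # Step 1: lift only the non-axis index (-1 on the last axis) into the components.
    lifted = np.concatenate([x[:, :, -1], y[:, :, -1]], axis=1)

    # Step 2: re-apply the index on the join axis; it stays at position 1 here
    # because no dimension to its left was dropped in step 1.
    np.testing.assert_allclose(joined[:, 0, -1], lifted[:, 0])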

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 55 additions & 1 deletion
@@ -42,7 +42,7 @@
     tensor3,
     vector,
 )
-from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
+from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
@@ -252,6 +252,9 @@ def test_local_subtensor_of_softmax(original_fn, expected_fn):
     )


+shared_axis = shared(1, "axis")
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
@@ -661,6 +664,57 @@ def test_empty_subtensor(self):
         assert local_subtensor_make_vector.transform(fgraph, node) == [v]


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        (
+            lambda x, y: concatenate([x, y], axis=1)[1],
+            lambda x, y: concatenate([x[1], y[1]], axis=0),
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=-1)[1:],
+            lambda x, y: concatenate([x[1:], y[1:]], axis=1),
+        ),
+        # Indexing on both axis of concatenation and somewhere else:
+        (
+            lambda x, y: concatenate([x, y], axis=1)[0, 1:],
+            lambda x, y: concatenate([x[0], y[0]], axis=0)[1:],
+        ),
+        # Not supported, indexing on axis of concatenation
+        (
+            lambda x, y: concatenate([x, y], axis=0)[0],
+            lambda x, y: concatenate([x, y], axis=0)[0],
+        ),
+        (
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
+        ),
+        # Not supported, axis of concatenation is dynamically determined
+        (
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
+        ),
+    ],
+)
+def test_local_subtensor_of_join(original_fn, expected_fn):
+    rng = np.random.default_rng(257)
+    x = pt.matrix("x", shape=(5, 3))
+    y = pt.matrix("y", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+    y_test = rng.normal(size=y.type.shape)
+
+    out = original_fn(x, y)
+    expected_opt_out = expected_fn(x, y)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_shape_constant():
     x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
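
For a quick interactive check of the new rewrite, a usage sketch only, assuming a PyTensor install that includes this commit; the variable names below are illustrative:

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph
    from pytensor.printing import debugprint

    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))

    out = pt.concatenate([x, y], axis=1)[0]  # Subtensor applied on top of a Join
    opt = rewrite_graph(out)                 # expected to become join(0, x[0], y[0])
    debugprint(opt, print_type=True)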
