Commit 2ad449b

Lift Subtensor over expand_dims
1 parent 19550e8 commit 2ad449b
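
In graph terms, the commit pushes the indexing below the dimension expansion. A minimal sketch of the intended effect (using `rewrite_graph` from `pytensor.graph.rewriting.utils`; the variable names are illustrative, not from the diff):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.tensor("x", shape=(5, 3))
    out = pt.expand_dims(x, axis=1)[0]  # Subtensor on top of an expand_dims

    # After rewriting, the graph should be equivalent to indexing first,
    # then expanding: expand_dims(x[0], axis=0)
    opt = rewrite_graph(out)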

File tree: 3 files changed (+157, -19 lines)

- pytensor/tensor/rewriting/subtensor.py
- pytensor/tensor/rewriting/subtensor_lift.py
- tests/tensor/rewriting/test_subtensor_lift.py

pytensor/tensor/rewriting/subtensor.py

Lines changed: 12 additions & 14 deletions
@@ -74,7 +74,7 @@
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType, integer_dtypes
-from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorConstant, TensorVariable


@@ -154,19 +154,17 @@ def transform_take(a, indices, axis):

 def is_full_slice(x):
     """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
-    if (
-        (isinstance(x, slice) and x == slice(None))
-        or (isinstance(x, SliceConstant) and x.value == slice(None))
-        or (
-            not isinstance(x, SliceConstant)
-            and isinstance(getattr(x, "type", None), SliceType)
-            and x.owner is not None
-            and all(
-                isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs
-            )
-        )
-    ):
-        return True
+    if isinstance(x, slice):
+        return x == slice(None)
+
+    if isinstance(x, Variable) and isinstance(x.type, SliceType):
+        if isinstance(x, Constant):
+            return x.data == slice(None)
+        else:
+            # Symbolic MakeSlice
+            # Ignores start = 0, step = 1 cases
+            return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs)
+
     return False

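For reference, a quick sketch of how the refactored helper classifies its three cases (the `make_slice` constructor from `pytensor.tensor.type_other` is assumed here for building a symbolic slice; it is not part of this diff):

    from pytensor.tensor.rewriting.subtensor import is_full_slice
    from pytensor.tensor.type_other import make_slice

    assert is_full_slice(slice(None))         # plain Python full slice
    assert not is_full_slice(slice(1, None))  # non-trivial start
    # Symbolic slice whose start/stop/step are all None
    assert is_full_slice(make_slice(None, None, None))
    # Per the "Ignores start = 0, step = 1 cases" comment, an equivalent
    # symbolic slice(0, None, 1) is conservatively reported as not full.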
pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 78 additions & 2 deletions
@@ -11,18 +11,19 @@
     MakeVector,
     alloc,
     as_tensor,
+    expand_dims,
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
-from pytensor.tensor.rewriting.subtensor import register_useless
+from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
 from pytensor.tensor.shape import (
     Shape,
     SpecifyShape,
@@ -37,6 +38,7 @@
     get_canonical_form_slice,
     get_constant_idx,
     get_idx_list,
+    indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
@@ -204,6 +206,80 @@ def local_subtensor_lift(fgraph, node):
         return [rbcast_subt_x]


+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Subtensor])
+def local_subtensor_of_expand_dims(fgraph, node):
+    """Lift a Subtensor through a DimShuffle that only expands dims.
+
+    expand_dims(x, axis=0)[0] -> x
+    expand_dims(x, axis=0)[:, 0] -> expand_dims(x[0], axis=0)
+    expand_dims(x, axis=2)[0] -> expand_dims(x[0], axis=1)
+
+    This goes beyond `local_subtensor_remove_broadcastable_index`, which
+    simply removes useless subtensors on broadcastable dimensions.
+    """
+    ds, *idx = node.inputs
+
+    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
+        return None
+
+    ds_op = ds.owner.op
+
+    if not ds_op.is_expand_dims:
+        return None
+
+    expanded_axes = ds_op.augment
+    [x] = ds.owner.inputs
+
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Keep indices on the original dimensions; drop indices on the expanded dimensions when safe
+    new_idxs = []
+    for i, idx_item in enumerate(idx_tuple):
+        if i in expanded_axes:
+            if isinstance(idx_item, slice):
+                # A slice could either keep or drop this dimension
+                if is_full_slice(idx_item):
+                    # A None slice always keeps the dimension.
+                    # We skip the index, and later introduce the needed expand_dims
+                    continue
+                else:
+                    # Other slices could keep or drop the dimension.
+                    # Get out instead of trying to figure out which case it is
+                    return None
+            else:
+                # Integer indexing can only drop the dimension (if it's a valid graph).
+                # We can just drop the index and avoid expanding the dimension.
+                # This is why this rewrite is tagged with "shape_unsafe"
+                continue
+        else:
+            # Keep indices on non-expanded dimensions
+            new_idxs.append(idx_item)

+    [old_out] = node.outputs
+    out = x[tuple(new_idxs)]
+    copy_stack_trace(old_out, out)
+
+    if out.type.broadcastable != old_out.type.broadcastable:
+        # Re-introduce the needed new dimensions (corresponding to full slices on the original expanded dimensions).
+        # If out.type.broadcastable == (False,) and old_out.type.broadcastable == (True, False, True),
+        # then axis = (0, 2)
+        old_bcast = list(old_out.type.broadcastable)
+        expanded_bcast = list(out.type.broadcastable)
+        axis = []
+        i = 0
+        while i < len(old_bcast):
+            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
+                expanded_bcast.insert(i, True)
+                axis.append(i)
+            i += 1
+        out = expand_dims(out, axis=axis)
+        copy_stack_trace(old_out, out)
+
+    return [out]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
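
The axis-recovery loop at the end of the new rewrite can be read in isolation. A standalone mirror of it (not part of the diff), showing how the example from the code comment plays out:

    def recover_expanded_axes(old_bcast, new_bcast):
        # Walk the old broadcast pattern and note every position where the
        # (shorter) new pattern is missing a broadcastable dimension,
        # exactly as the while-loop in local_subtensor_of_expand_dims does.
        expanded_bcast = list(new_bcast)
        axis = []
        i = 0
        while i < len(old_bcast):
            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
                expanded_bcast.insert(i, True)
                axis.append(i)
            i += 1
        return axis

    # The case from the comment: old pattern (True, False, True) vs new (False,)
    assert recover_expanded_axes((True, False, True), (False,)) == [0, 2]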

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 67 additions & 3 deletions
@@ -24,7 +24,9 @@
     Type,
     rewrite_graph,
 )
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.rewriting.basic import check_stack_trace
+from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
     exp,
@@ -43,7 +45,7 @@
     tensor3,
     vector,
 )
-from pytensor.tensor.basic import MakeVector, make_vector
+from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
@@ -53,6 +55,9 @@
 from pytensor.tensor.subtensor import Subtensor


+NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
+
+
 class TestLocalSubtensorLift:
     def test_basic(self):
         # basic test that the Op works
@@ -134,8 +139,8 @@ def test_basic_4(self):
         assert check_stack_trace(f, ops_to_check="all")

         prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, DimShuffle)
-        assert isinstance(prog[1].op, Subtensor)
+        assert isinstance(prog[0].op, Subtensor)
+        assert isinstance(prog[1].op, DimShuffle)
         assert prog[2].op == exp
         assert len(prog) == 3
         f([4, 5])  # let debugmode test something
@@ -256,6 +261,65 @@ def test_basic_8(self):
         assert (f4(zval) == zval[:, 3, 0]).all()


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Integer indexing
+        (lambda x: expand_dims(x, axis=0)[0], lambda x: x),
+        (
+            lambda x: expand_dims(x, axis=1)[0],
+            lambda x: expand_dims(x[0], axis=0),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 3))[0],
+            lambda x: expand_dims(x[0], axis=(0, 2)),
+        ),
+        # Slice indexing
+        (
+            lambda x: expand_dims(x, axis=1)[1:],
+            lambda x: expand_dims(x[1:], axis=1),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 3))[1:],
+            lambda x: expand_dims(x[1:], axis=(1, 3)),
+        ),
+        # Not supported: slice indexing on an expanded dimension
+        (
+            lambda x: expand_dims(x, axis=0)[1:],
+            lambda x: expand_dims(x, axis=0)[1:],
+        ),
+        # Mixed indexing
+        (
+            lambda x: expand_dims(x, axis=1)[0, :, 1:],
+            lambda x: expand_dims(x[0, 1:], axis=0),
+        ),
+        (
+            lambda x: expand_dims(x, axis=1)[1:, :, 0],
+            lambda x: expand_dims(x[1:, 0], axis=1),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 2))[1:, :, 0],
+            lambda x: expand_dims(x[1:], axis=1),
+        ),
+    ],
+)
+def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
+    rng = np.random.default_rng(232)
+    x = tensor("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [opt_out, expected_opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_alloc():
     # DebugMode should detect if something goes wrong.
     # test shape combinations of odd and even shapes.
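
The rewrite's entry checks rely on two `DimShuffle` properties used in the diff, `is_expand_dims` and `augment`. A small sketch of what the rewrite assumes they report (the printed values are expectations, not captured output):

    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    x = pt.tensor("x", shape=(5, 3))
    y = pt.expand_dims(x, axis=(1, 3))

    ds_op = y.owner.op
    assert isinstance(ds_op, DimShuffle)
    # `is_expand_dims` gates the rewrite; `augment` lists the positions of the
    # inserted broadcastable axes, which become `expanded_axes`.
    print(ds_op.is_expand_dims, ds_op.augment)  # expected: True [1, 3]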
