Commit f5a13f2: Lift Subtensor over expand_dims
1 parent 9e18d3c

File tree

3 files changed: +162 -19 lines


pytensor/tensor/rewriting/subtensor.py

Lines changed: 16 additions & 14 deletions
@@ -75,7 +75,7 @@
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType, integer_dtypes
-from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorConstant, TensorVariable


@@ -155,19 +155,21 @@ def transform_take(a, indices, axis):

 def is_full_slice(x):
     """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
-    if (
-        (isinstance(x, slice) and x == slice(None))
-        or (isinstance(x, SliceConstant) and x.value == slice(None))
-        or (
-            not isinstance(x, SliceConstant)
-            and isinstance(getattr(x, "type", None), SliceType)
-            and x.owner is not None
-            and all(
-                isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs
-            )
-        )
-    ):
-        return True
+    if isinstance(x, slice):
+        return x == slice(None)
+
+    if isinstance(x, Variable) and isinstance(x.type, SliceType):
+        if x.owner is None:
+            if isinstance(x, Constant):
+                return x.data == slice(None)
+            else:
+                # Root slice variable
+                return False
+
+        # Symbolic MakeSlice
+        # Ignores start = 0, step = 1 cases
+        return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs)
+
     return False

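As a quick sanity check of the literal-slice branch of the rewritten is_full_slice, a minimal sketch follows; the symbolic Constant and MakeSlice branches require building PyTensor graph objects and are only described in comments here.

from pytensor.tensor.rewriting.subtensor import is_full_slice

# Literal Python slices: only slice(None) counts as a full slice.
assert is_full_slice(slice(None))
assert not is_full_slice(slice(1, None))
assert not is_full_slice(slice(None, None, 2))

# Symbolic cases handled by the new branches (not constructed here):
#   * a Constant of SliceType is a full slice iff its .data == slice(None)
#   * a root (ownerless, non-constant) slice variable is conservatively False
#   * a MakeSlice node is a full slice iff all of its inputs are NoneTypeT,
#     so an explicit start=0 or step=1 is deliberately not recognized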

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 78 additions & 2 deletions
@@ -11,18 +11,19 @@
     MakeVector,
     alloc,
     as_tensor,
+    expand_dims,
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
-from pytensor.tensor.rewriting.subtensor import register_useless
+from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
 from pytensor.tensor.shape import (
     Shape,
     SpecifyShape,
@@ -37,6 +38,7 @@
     get_canonical_form_slice,
     get_constant_idx,
     get_idx_list,
+    indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
@@ -204,6 +206,80 @@ def local_subtensor_lift(fgraph, node):
         return [rbcast_subt_x]


+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Subtensor])
+def local_subtensor_of_expand_dims(fgraph, node):
+    """Lift a Subtensor through a DimShuffle that only expands dims.
+
+    expand_dims(x, axis=0)[0] -> x
+    expand_dims(x, axis=0)[:, 0] -> expand_dims(x[0], axis=0)
+    expand_dims(x, axis=2)[0] -> expand_dims(x[0], axis=1)
+
+    This goes beyond `local_subtensor_remove_broadcastable_index`, which
+    simply removes useless subtensors on broadcastable dimensions.
+    """
+    ds, *idx = node.inputs
+
+    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
+        return None
+
+    ds_op = ds.owner.op
+
+    if not ds_op.is_expand_dims:
+        return None
+
+    expanded_axes = ds_op.augment
+    [x] = ds.owner.inputs
+
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe
+    new_idxs = []
+    for i, idx_item in enumerate(idx_tuple):
+        if i in expanded_axes:
+            if isinstance(idx_item, slice):
+                # Slice could be keeping or dropping this dimension
+                if is_full_slice(idx_item):
+                    # A None slice always keeps the dimension.
+                    # We skip the index, and later introduce the needed expand_dims
+                    continue
+                else:
+                    # Other slices could keep or drop the dimension.
+                    # Get out instead of trying to figure out which case it is
+                    return None
+            else:
+                # Integer indexing can only drop the dimension (if it's a valid graph).
+                # We can just drop the index and avoid expanding the dimension.
+                # This is why this rewrite is tagged with "shape_unsafe"
+                continue
+        else:
+            # Keep indexes for non-expanded dimensions
+            new_idxs.append(idx_item)
+
+    [old_out] = node.outputs
+    out = x[tuple(new_idxs)]
+    copy_stack_trace(old_out, out)
+
+    if out.type.broadcastable != old_out.type.broadcastable:
+        # Re-introduce the needed new dimensions (corresponding to full slices on the originally expanded dimensions)
+        # If out.type.broadcastable == (False,) and old_out.type.broadcastable == (True, False, True)
+        # then axis = (0, 2)
+        old_bcast = list(old_out.type.broadcastable)
+        expanded_bcast = list(out.type.broadcastable)
+        axis = []
+        i = 0
+        while i < len(old_bcast):
+            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
+                expanded_bcast.insert(i, True)
+                axis.append(i)
+            i += 1
+        out = expand_dims(out, axis=axis)
+        copy_stack_trace(old_out, out)
+
+    return [out]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
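The broadcastable-pattern loop at the end of local_subtensor_of_expand_dims can be checked in isolation. Below is a minimal standalone sketch of the same logic; missing_expand_axes is a hypothetical helper name, not part of this commit.

def missing_expand_axes(old_bcast, new_bcast):
    # Find the axes at which length-1 dimensions must be re-inserted
    # into new_bcast so that it lines up with old_bcast, mirroring the
    # while-loop in the rewrite above.
    expanded = list(new_bcast)
    axis = []
    i = 0
    while i < len(old_bcast):
        if i == len(expanded) or expanded[i] != old_bcast[i]:
            expanded.insert(i, True)
            axis.append(i)
        i += 1
    return axis

# The example from the code comment: out has broadcastable (False,)
# while the original output had (True, False, True), so the missing
# dimensions are re-added at axes 0 and 2.
assert missing_expand_axes((True, False, True), (False,)) == [0, 2]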

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 68 additions & 3 deletions
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 import unittest_tools as utt
+from tensor.rewriting.test_subtensor import mode_opt

 from pytensor import (
     Mode,
@@ -23,7 +24,9 @@
     Type,
     rewrite_graph,
 )
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.rewriting.basic import check_stack_trace
+from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
     exp,
@@ -42,7 +45,7 @@
     tensor3,
     vector,
 )
-from pytensor.tensor.basic import MakeVector, make_vector
+from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
@@ -58,6 +61,9 @@
 mode_opt = get_mode(mode_opt)


+NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
+
+
 class TestLocalSubtensorLift:
     def test_basic(self):
         # basic test that the Op works
@@ -139,8 +145,8 @@ def test_basic_4(self):
         assert check_stack_trace(f, ops_to_check="all")

         prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, DimShuffle)
-        assert isinstance(prog[1].op, Subtensor)
+        assert isinstance(prog[0].op, Subtensor)
+        assert isinstance(prog[1].op, DimShuffle)
         assert prog[2].op == exp
         assert len(prog) == 3
         f([4, 5])  # let debugmode test something
@@ -261,6 +267,65 @@ def test_basic_8(self):
         assert (f4(zval) == zval[:, 3, 0]).all()


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Integer indexing
+        (lambda x: expand_dims(x, axis=0)[0], lambda x: x),
+        (
+            lambda x: expand_dims(x, axis=1)[0],
+            lambda x: expand_dims(x[0], axis=0),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 3))[0],
+            lambda x: expand_dims(x[0], axis=(0, 2)),
+        ),
+        # Slice indexing
+        (
+            lambda x: expand_dims(x, axis=1)[1:],
+            lambda x: expand_dims(x[1:], axis=1),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 3))[1:],
+            lambda x: expand_dims(x[1:], axis=(1, 3)),
+        ),
+        # Not supported: slice indexing on an expanded dimension
+        (
+            lambda x: expand_dims(x, axis=0)[1:],
+            lambda x: expand_dims(x, axis=0)[1:],
+        ),
+        # Mixed indexing
+        (
+            lambda x: expand_dims(x, axis=1)[0, :, 1:],
+            lambda x: expand_dims(x[0, 1:], axis=0),
+        ),
+        (
+            lambda x: expand_dims(x, axis=1)[1:, :, 0],
+            lambda x: expand_dims(x[1:, 0], axis=1),
+        ),
+        (
+            lambda x: expand_dims(x, axis=(1, 2))[1:, :, 0],
+            lambda x: expand_dims(x[1:], axis=1),
+        ),
+    ],
+)
+def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
+    rng = np.random.default_rng(232)
+    x = tensor("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [opt_out, expected_opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_alloc():
     # DebugMode should detect if something goes wrong.
     # Test shape combinations of odd and even shapes.