
Commit 2491bc2

Specialize AdvancedSubtensor1 mode for compile time valid indices
1 parent 0f5da80 commit 2491bc2

File tree

2 files changed: +76 -27 lines changed


pytensor/tensor/subtensor.py

Lines changed: 51 additions & 27 deletions
@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
         out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
         return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
 
-    def perform(self, node, inp, out_):
+    def perform(self, node, inp, output_storage):
         x, i = inp
-        (out,) = out_
-        # Copy always implied by numpy advanced indexing semantic.
-        if out[0] is not None and out[0].shape == (len(i),) + x.shape[1:]:
-            o = out[0]
-        else:
-            o = None
 
-        out[0] = x.take(i, axis=0, out=o)
+        # Numpy take is always slower when out is provided
+        # https://github.com/numpy/numpy/issues/28636
+        output_storage[0][0] = x.take(i, axis=0, out=None)
 
     def connection_pattern(self, node):
         rval = [[True], *([False] for _ in node.inputs[1:])]
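
Note: the new perform comment can be checked directly against NumPy. A minimal standalone sketch (not part of the commit; timings vary by NumPy version and machine) comparing np.take with a freshly allocated result versus a preallocated out buffer, per https://github.com/numpy/numpy/issues/28636:

import timeit

import numpy as np

x = np.random.normal(size=(1000, 16))
idx = np.random.randint(0, 1000, size=10_000)
buf = np.empty((idx.shape[0], 16), dtype=x.dtype)

# Let np.take allocate the result itself (what perform now does with out=None)
t_alloc = timeit.timeit(lambda: np.take(x, idx, axis=0), number=1_000)
# Reuse a preallocated buffer; reported to be slower in the issue above
t_out = timeit.timeit(lambda: np.take(x, idx, axis=0, out=buf), number=1_000)
print(f"fresh allocation: {t_alloc:.4f}s, preallocated out: {t_out:.4f}s")
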
@@ -2174,42 +2170,70 @@ def c_code(self, node, name, input_names, output_names, sub):
                 "c_code defined for AdvancedSubtensor1, not for child class",
                 type(self),
             )
+        x, idxs = node.inputs
+        shape0 = x.type.shape[0]
+        if (
+            shape0 is not None
+            and isinstance(idxs, Constant)
+            and (
+                (idxs.data.max() < shape0)
+                and ((idxs.data.min() >= 0) or (idxs.data.min() > -shape0))
+            )
+        ):
+            # We can know ahead of time that all indices are valid, so we can use a faster mode
+            mode = "NPY_WRAP"  # This seems to be faster than NPY_CLIP
+        else:
+            mode = "NPY_RAISE"
         a_name, i_name = input_names[0], input_names[1]
         output_name = output_names[0]
         fail = sub["fail"]
-        return f"""
-        if ({output_name} != NULL) {{
-            npy_intp nd, i, *shape;
-            nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
-            if (PyArray_NDIM({output_name}) != nd) {{
+        if mode == "NPY_RAISE":
+            # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
+            # We can remove this special case after https://github.com/numpy/numpy/issues/28636
+            manage_pre_allocated_out = f"""
+            if ({output_name} != NULL) {{
+                // Numpy TakeFrom is always slower when copying
+                // https://github.com/numpy/numpy/issues/28636
                 Py_CLEAR({output_name});
             }}
-            else {{
-                shape = PyArray_DIMS({output_name});
-                for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
-                    if (shape[i] != PyArray_DIMS({i_name})[i]) {{
-                        Py_CLEAR({output_name});
-                        break;
-                    }}
+            """
+        else:
+            manage_pre_allocated_out = f"""
+            if ({output_name} != NULL) {{
+                npy_intp nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
+                if (PyArray_NDIM({output_name}) != nd) {{
+                    Py_CLEAR({output_name});
                 }}
-                if ({output_name} != NULL) {{
-                    for (; i < nd; i++) {{
-                        if (shape[i] != PyArray_DIMS({a_name})[
-                                i-PyArray_NDIM({i_name})+1]) {{
+                else {{
+                    int i;
+                    npy_intp* shape = PyArray_DIMS({output_name});
+                    for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
+                        if (shape[i] != PyArray_DIMS({i_name})[i]) {{
                             Py_CLEAR({output_name});
                             break;
                         }}
                     }}
+                    if ({output_name} != NULL) {{
+                        for (; i < nd; i++) {{
+                            if (shape[i] != PyArray_DIMS({a_name})[i-PyArray_NDIM({i_name})+1]) {{
+                                Py_CLEAR({output_name});
+                                break;
+                            }}
+                        }}
+                    }}
                 }}
             }}
-        }}
+            """
+
+        return f"""
+        {manage_pre_allocated_out}
         {output_name} = (PyArrayObject*)PyArray_TakeFrom(
-                   {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
+                   {a_name}, (PyObject*){i_name}, 0, {output_name}, {mode});
         if ({output_name} == NULL) {fail};
         """
 
     def c_code_cache_version(self):
-        return (4,)
+        return (5,)
 
 
 advanced_subtensor1 = AdvancedSubtensor1()
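
Note: NPY_WRAP and NPY_RAISE above are the C-level counterparts of np.take's mode argument. A hedged Python-level illustration (not from the commit) of why the cheaper mode is safe once the static shape and constant indices prove every index is in range:

import numpy as np

x = np.arange(5) * 10.0                        # shape (5,)
idx = np.array([0, 2, 4, 1])                   # every index valid for length 5

print(np.take(x, idx, mode="raise"))           # bounds-checked: [ 0. 20. 40. 10.]
print(np.take(x, idx, mode="wrap"))            # same result, without the error path
print(np.take(x, np.array([7]), mode="wrap"))  # out of range wraps: 7 % 5 == 2 -> [20.]

mode="raise" validates every index, while "wrap" reduces indices modulo the axis length, so the two only agree when the indices are already known to be valid, which is exactly the condition the new c_code verifies at compile time.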

tests/tensor/test_subtensor.py

Lines changed: 25 additions & 0 deletions
@@ -3003,3 +3003,28 @@ def test_flip(size: tuple[int]):
     z = flip(x_pt, axis=list(axes))
     f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
     np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
+
+
+class TestBenchmarks:
+    @pytest.mark.parametrize(
+        "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
+    )
+    @pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
+    def test_advanced_subtensor1(self, static_shape, gc, benchmark):
+        x = vector("x", shape=(85 if static_shape else None,))
+
+        x_values = np.random.normal(size=(85,))
+        idxs_values = np.arange(85).repeat(11)
+
+        # With static shape and constant indices we know all idxs are valid
+        # And can use faster mode in numpy.take
+        out = x[idxs_values]
+
+        fn = pytensor.function(
+            [x],
+            pytensor.Out(out, borrow=True),
+            on_unused_input="ignore",
+            trust_input=True,
+        )
+        fn.vm.allow_gc = gc
+        benchmark(fn, x_values, idxs_values)
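
Note: a standalone timing sketch (not part of the commit) that mirrors the benchmark above without pytest-benchmark; it drops the Out(borrow=True)/trust_input/allow_gc settings for simplicity, assumes the C backend is available, and absolute numbers will vary:

import timeit

import numpy as np
import pytensor
from pytensor.tensor import vector

x_values = np.random.normal(size=(85,))
idxs_values = np.arange(85).repeat(11)

for static_shape in (False, True):
    x = vector("x", shape=(85 if static_shape else None,))
    # With a static shape and constant indices the specialized c_code can use NPY_WRAP
    out = x[idxs_values]
    fn = pytensor.function([x], out)
    t = timeit.timeit(lambda: fn(x_values), number=10_000)
    print(f"static_shape={static_shape}: {t:.3f}s for 10,000 calls")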
