Skip to content

Commit 8912887

Browse files
committed
Specialized C-impl for vector AdvancedIncSubtensor1
1 parent b31e924 commit 8912887

File tree

2 files changed

+119
-5
lines changed

2 files changed

+119
-5
lines changed

pytensor/tensor/subtensor.py

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,6 +2326,9 @@ def copy_of_x(self, x):
23262326
NPY_ARRAY_ENSURECOPY, NULL)"""
23272327

23282328
def c_support_code(self, **kwargs):
2329+
if numpy_version < "1.8.0" or using_numpy_2:
2330+
return None
2331+
23292332
types = [
23302333
"npy_" + t
23312334
for t in [
@@ -2516,15 +2519,100 @@ def gen_num(typen):
25162519
return code
25172520

25182521
def c_code(self, node, name, input_names, output_names, sub):
# Build the C source string for this op from the node's input/output types.
# Two code paths are visible in this diff:
#   * a specialized loop, used when x is 1-d, neither x nor y has a complex
#     dtype, and y's broadcastable pattern equals idx's;
#   * the pre-existing generic template (partly elided in this view), which
#     raises NotImplementedError when numpy_version < "1.8.0" or using_numpy_2.
2519-
if numpy_version < "1.8.0" or using_numpy_2:
2520-
raise NotImplementedError
2521-
25222522
x, y, idx = input_names
2523-
out = output_names[0]
2523+
[out] = output_names
25242524
copy_of_x = self.copy_of_x(x)
25252525
params = sub["params"]
25262526
fail = sub["fail"]
25272527

2528+
x_, y_, idx_ = node.inputs
2529+
y_dtype = y_.type.dtype_specs()[1]
2530+
idx_dtype = idx_.type.dtype_specs()[1]
2531+
out_dtype = node.outputs[0].type.dtype_specs()[1]
# y_bcast is True when y's broadcastable pattern differs from idx's; the
# specialized path below is only taken when the two patterns match.
2532+
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
2533+
if (
2534+
x_.type.ndim == 1
2535+
and x_.type.dtype not in complex_dtypes
2536+
and not y_bcast
2537+
and y_.type.dtype not in complex_dtypes
2538+
):
2539+
# Simple implementation for vector x, y cases
# The negative-index wrap is only skipped when idx is a Constant whose
# minimum value is >= 0; the bounds check is governed separately by
# AdvancedSubtensor1._idx_may_be_invalid (defined outside this view).
2540+
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
2541+
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
2542+
shape0 = x_.type.shape[0]
2543+
# This is used to make sure that when we trust the indices to be valid
2544+
# we are not fooled by a wrong static shape
2545+
unexpected_shape0 = (
2546+
f"PyArray_SHAPE({x})[0] != {shape0}" if shape0 is not None else "0"
2547+
)
2548+
# `op` selects assignment vs. accumulation in the generated C statement.
# idx_may_be_neg / idx_may_be_invalid are interpolated below as literal 0/1
# inside C `if (...)` conditions.
# NOTE(review): the generated loop declares its counter as `int i` while the
# bound `n` is `npy_intp`; where npy_intp is wider than int this looks unsafe
# for very long idx arrays -- consider `npy_intp i`. Confirm.
# NOTE(review): `idx` is declared with the index array's own C type
# ({idx_dtype}), so the result of `idx += out_shape0` is stored back into that
# type; for narrow index dtypes (e.g. 8-bit) and a longer x this looks like it
# can wrap before the bounds check -- consider holding idx in npy_intp. Confirm.
2549+
op = "=" if self.set_instead_of_inc else "+="
2550+
code = f"""
2551+
if ({params}->inplace)
2552+
{{
2553+
if ({x} != {out})
2554+
{{
2555+
Py_XDECREF({out});
2556+
Py_INCREF({x});
2557+
{out} = {x};
2558+
}}
2559+
}}
2560+
else
2561+
{{
2562+
Py_XDECREF({out});
2563+
{out} = {copy_of_x};
2564+
if (!{out}) {{
2565+
// Exception already set
2566+
{fail}
2567+
}}
2568+
}}
2569+
2570+
if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
2571+
PyErr_SetString(PyExc_ValueError, "Input x to AdvancedIncSubtensor1 does not have right shape or ndim");
2572+
{fail}
2573+
}}
2574+
if (PyArray_NDIM({idx}) != 1) {{
2575+
PyErr_SetString(PyExc_ValueError, "Input idx to AdvancedIncSubtensor1 ndim != 1");
2576+
{fail}
2577+
}}
2578+
if ((PyArray_NDIM({y}) != 1) || (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0])) {{
2579+
PyErr_SetString(PyExc_ValueError, "Input y to AdvancedIncSubtensor1 does not have right shape or ndim");
2580+
{fail}
2581+
}}
2582+
2583+
{{
2584+
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
2585+
{out_dtype}* out_data = ({out_dtype}*)PyArray_DATA({out});
2586+
{y_dtype}* y_data = ({y_dtype}*)PyArray_DATA({y});
2587+
{idx_dtype}* idx_data = ({idx_dtype}*)PyArray_DATA({idx});
2588+
npy_intp n = PyArray_SHAPE({idx})[0];
2589+
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
2590+
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
2591+
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
2592+
2593+
for(int i = 0; i < n; i++){{
2594+
{idx_dtype} idx = idx_data[i * idx_jump];
2595+
if ({int(idx_may_be_neg)}){{
2596+
if (idx < 0) {{
2597+
idx += out_shape0;
2598+
}}
2599+
}}
2600+
if ({int(idx_may_be_invalid)}){{
2601+
if ((idx < 0) || (idx >= out_shape0)) {{
2602+
PyErr_Format(PyExc_IndexError,"index out of bounds");
2603+
{fail}
2604+
}}
2605+
}}
2606+
out_data[idx * out_jump] {op} y_data[i * y_jump];
2607+
}}
2608+
2609+
}}
2610+
"""
2611+
return code
2612+
# Everything below is the generic template; this commit moves the numpy
# version guard here so it no longer blocks the specialized path above.
2613+
if numpy_version < "1.8.0" or using_numpy_2:
2614+
raise NotImplementedError
2615+
25282616
return f"""
25292617
PyObject* rval = NULL;
25302618
if ({params}->inplace)
@@ -2552,7 +2640,7 @@ def c_code(self, node, name, input_names, output_names, sub):
25522640
"""
25532641

25542642
def c_code_cache_version(self):
# Version tuple for the generated C code. This commit replaces (8,) with (9,)
# alongside the c_code change -- presumably so previously cached compiled
# modules are not reused; confirm against the compile-cache documentation.
2555-
return (8,)
2643+
return (9,)
25562644

25572645
def perform(self, node, inp, out_):
25582646
x, y, idx = inp

tests/tensor/test_subtensor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,3 +3028,29 @@ def test_advanced_subtensor1(self, static_shape, gc, benchmark):
30283028
)
30293029
fn.vm.allow_gc = gc
30303030
benchmark(fn, x_values, idxs_values)
3031+
3032+
@pytest.mark.parametrize(
    "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
)
@pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
    """Benchmark ``func`` applied to a zeros vector indexed by an integer array.

    ``func`` is either ``inc_subtensor`` or ``set_subtensor``. The vector
    input has length 85 (declared statically or left unknown, depending on
    ``static_shape``), and each position is addressed 11 times by the index
    array. Two outputs are compiled from the same zeros buffer: one using
    the indices as-is and one using them reversed.
    """
    n_items, n_repeats = 85, 11

    x = vector("x", shape=(n_items if static_shape else None,))
    x_values = np.zeros((n_items,))
    zeros = ptb.zeros_like(x)
    y_values = np.random.normal(size=(n_items * n_repeats,))
    idxs_values = np.arange(n_items).repeat(n_repeats)

    # With static shape and constant indices we know all idxs are valid
    # Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor
    results = [
        func(zeros[idxs], y_values)
        for idxs in (idxs_values, idxs_values[::-1])
    ]

    fn = pytensor.function(
        [x],
        [pytensor.Out(res, borrow=True) for res in results],
        on_unused_input="ignore",
        trust_input=True,
    )
    fn.vm.allow_gc = gc
    benchmark(fn, x_values)

0 commit comments

Comments
 (0)