Skip to content

Commit 7c3d451

Browse files
committed
Specialized C-impl for vector AdvancedIncSubtensor1
1 parent 2491bc2 commit 7c3d451

File tree

2 files changed

+125
-5
lines changed

2 files changed

+125
-5
lines changed

pytensor/tensor/subtensor.py

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2320,6 +2320,9 @@ def copy_of_x(self, x):
23202320
NPY_ARRAY_ENSURECOPY, NULL)"""
23212321

23222322
def c_support_code(self, **kwargs):
2323+
if numpy_version < "1.8.0" or using_numpy_2:
2324+
return None
2325+
23232326
types = [
23242327
"npy_" + t
23252328
for t in [
@@ -2510,15 +2513,105 @@ def gen_num(typen):
25102513
return code
25112514

25122515
def c_code(self, node, name, input_names, output_names, sub):
2513-
if numpy_version < "1.8.0" or using_numpy_2:
2514-
raise NotImplementedError
2515-
25162516
x, y, idx = input_names
2517-
out = output_names[0]
2517+
[out] = output_names
25182518
copy_of_x = self.copy_of_x(x)
25192519
params = sub["params"]
25202520
fail = sub["fail"]
25212521

2522+
x_, y_, idx_ = node.inputs
2523+
y_dtype = y_.type.dtype_specs()[1]
2524+
idx_dtype = idx_.type.dtype_specs()[1]
2525+
out_dtype = node.outputs[0].type.dtype_specs()[1]
2526+
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
2527+
if (
2528+
x_.type.ndim == 1
2529+
and x_.type.dtype not in complex_dtypes
2530+
and not y_bcast
2531+
and y_.type.dtype not in complex_dtypes
2532+
):
2533+
# Simple implementation for vector x, y cases
2534+
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
2535+
shape0 = x_.type.shape[0]
2536+
idx_may_be_invalid = not (
2537+
shape0 is not None
2538+
and isinstance(idx_, Constant)
2539+
and (idx_.data.min() > 0 or idx_.data.min() >= -shape0)
2540+
and (idx_.data.max() < 0 or idx_.data.max() < shape0)
2541+
)
2542+
# This is used to make sure that when we trust the indices to be valid
2543+
# we are not fooled by a wrong static shape
2544+
unexpected_shape0 = (
2545+
f"PyArray_SHAPE({x})[0] != {shape0}" if shape0 is not None else "0"
2546+
)
2547+
2548+
op = "=" if self.set_instead_of_inc else "+="
2549+
code = f"""
2550+
if ({params}->inplace)
2551+
{{
2552+
if ({x} != {out})
2553+
{{
2554+
Py_XDECREF({out});
2555+
Py_INCREF({x});
2556+
{out} = {x};
2557+
}}
2558+
}}
2559+
else
2560+
{{
2561+
Py_XDECREF({out});
2562+
{out} = {copy_of_x};
2563+
if (!{out}) {{
2564+
// Exception already set
2565+
{fail}
2566+
}}
2567+
}}
2568+
2569+
if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
2570+
PyErr_SetString(PyExc_ValueError, "Input x to AdvancedIncSubtensor1 does not have right shape or ndim");
2571+
{fail}
2572+
}}
2573+
if (PyArray_NDIM({idx}) != 1) {{
2574+
PyErr_SetString(PyExc_ValueError, "Input idx to AdvancedIncSubtensor1 ndim != 1");
2575+
{fail}
2576+
}}
2577+
if ((PyArray_NDIM({y}) != 1) || (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0])) {{
2578+
PyErr_SetString(PyExc_ValueError, "Input y to AdvancedIncSubtensor1 does not have right shape or ndim");
2579+
{fail}
2580+
}}
2581+
2582+
{{
2583+
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
2584+
{out_dtype}* out_data = ({out_dtype}*)PyArray_DATA({out});
2585+
{y_dtype}* y_data = ({y_dtype}*)PyArray_DATA({y});
2586+
{idx_dtype}* idx_data = ({idx_dtype}*)PyArray_DATA({idx});
2587+
npy_intp n = PyArray_SHAPE({idx})[0];
2588+
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
2589+
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
2590+
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
2591+
2592+
for(int i = 0; i < n; i++){{
2593+
{idx_dtype} idx = idx_data[i * idx_jump];
2594+
if ({int(idx_may_be_neg)}){{
2595+
if (idx < 0) {{
2596+
idx += out_shape0;
2597+
}}
2598+
}}
2599+
if ({int(idx_may_be_invalid)}){{
2600+
if ((idx < 0) || (idx >= out_shape0)) {{
2601+
PyErr_Format(PyExc_IndexError,"index out of bounds");
2602+
{fail}
2603+
}}
2604+
}}
2605+
out_data[idx * out_jump] {op} y_data[i * y_jump];
2606+
}}
2607+
2608+
}}
2609+
"""
2610+
return code
2611+
2612+
if numpy_version < "1.8.0" or using_numpy_2:
2613+
raise NotImplementedError
2614+
25222615
return f"""
25232616
PyObject* rval = NULL;
25242617
if ({params}->inplace)
@@ -2546,7 +2639,8 @@ def c_code(self, node, name, input_names, output_names, sub):
25462639
"""
25472640

25482641
def c_code_cache_version(self):
    """Return the version tuple of the C code generated by this Op.

    The version is ``(9,)``, bumped from ``(8,)`` because the C code
    emitted by ``c_code`` changed.

    Returns
    -------
    tuple of int
        The current C-code version, ``(9,)``.
    """
    # A stray ``return None`` placed ahead of this statement would make the
    # version bump unreachable.
    # NOTE(review): a falsy version presumably marks the Op as unversioned
    # and so prevents reuse of cached compiled modules -- confirm against
    # the COp documentation.
    return (9,)
25502644

25512645
def perform(self, node, inp, out_):
25522646
x, y, idx = inp

tests/tensor/test_subtensor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,3 +3028,29 @@ def test_advanced_subtensor1(self, static_shape, gc, benchmark):
30283028
)
30293029
fn.vm.allow_gc = gc
30303030
benchmark(fn, x_values, idxs_values)
3031+
3032+
@pytest.mark.parametrize(
    "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
)
@pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
    """Benchmark ``func`` applied to a zeros buffer with constant vector indices.

    Parameters
    ----------
    func
        Either ``inc_subtensor`` or ``set_subtensor``.
    static_shape
        When true, ``x`` is declared with a static length of 85; otherwise
        its length is left unknown (``None``).
    gc
        Value assigned to ``fn.vm.allow_gc`` on the compiled function.
    benchmark
        Benchmark fixture; called with the compiled function and its input.
    """
    length = 85
    repeats = 11

    if static_shape:
        declared_shape = (length,)
    else:
        declared_shape = (None,)
    x = vector("x", shape=declared_shape)

    x_test = np.zeros((length,))
    zeros = ptb.zeros_like(x)
    y_test = np.random.normal(size=(length * repeats,))
    idxs_forward = np.arange(length).repeat(repeats)
    idxs_backward = idxs_forward[::-1]

    # With a static shape and constant indices, all indices are known to be
    # valid.
    # Both outputs update the same zeros buffer: the intent is to check that
    # a fresh buffer is allocated for each output rather than copied inside
    # IncSubtensor.
    updated = [
        func(zeros[idxs], y_test) for idxs in (idxs_forward, idxs_backward)
    ]
    borrowed_outputs = [pytensor.Out(out, borrow=True) for out in updated]

    compiled = pytensor.function(
        [x],
        borrowed_outputs,
        on_unused_input="ignore",
        trust_input=True,
    )
    compiled.vm.allow_gc = gc
    benchmark(compiled, x_test)

0 commit comments

Comments
 (0)