Skip to content

Commit 359d42c

Browse files
committed
Specialized C-impl for vector AdvancedIncSubtensor1
Also add checks for runtime broadcast
1 parent 1c8ac55 commit 359d42c

File tree

5 files changed

+186
-15
lines changed

5 files changed

+186
-15
lines changed

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
6767
if len(indices) == 1:
6868
indices = indices[0]
6969

70+
if isinstance(op, AdvancedIncSubtensor1):
71+
op._check_runtime_broadcasting(x, y, indices)
72+
7073
return jax_fn(x, indices, y)
7174

7275
return incsubtensor

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
287287
inplace = op.inplace
288288
set_instead_of_inc = op.set_instead_of_inc
289289
x, vals, idxs = node.inputs
290-
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
291-
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
290+
broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
291+
# TODO: Add runtime_broadcast check
292292

293293
if set_instead_of_inc:
294-
if broadcast:
294+
if broadcast_with_index:
295295

296296
@numba_njit(boundscheck=True)
297297
def advancedincsubtensor1_inplace(x, val, idxs):
@@ -318,7 +318,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
318318
x[idx] = val
319319
return x
320320
else:
321-
if broadcast:
321+
if broadcast_with_index:
322322

323323
@numba_njit(boundscheck=True)
324324
def advancedincsubtensor1_inplace(x, val, idxs):

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
109109

110110
def adv_set_subtensor(x, y, *indices):
111111
check_negative_steps(indices)
112+
if isinstance(op, AdvancedIncSubtensor1):
113+
op._check_runtime_broadcasting(x, y, indices)
112114
if not inplace:
113115
x = x.clone()
114116
x[indices] = y.type_as(x)
@@ -120,6 +122,8 @@ def adv_set_subtensor(x, y, *indices):
120122

121123
def adv_inc_subtensor_no_duplicates(x, y, *indices):
122124
check_negative_steps(indices)
125+
if isinstance(op, AdvancedIncSubtensor1):
126+
op._check_runtime_broadcasting(x, y, indices)
123127
if not inplace:
124128
x = x.clone()
125129
x[indices] += y.type_as(x)

pytensor/tensor/subtensor.py

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,6 +2257,12 @@ class AdvancedIncSubtensor1(COp):
22572257
check_input = False
22582258
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
22592259

2260+
_runtime_broadcast_error_msg = (
2261+
"Runtime broadcasting not allowed. "
2262+
"AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
2263+
"If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
2264+
)
2265+
22602266
def __init__(self, inplace=False, set_instead_of_inc=False):
22612267
self.inplace = bool(inplace)
22622268
self.set_instead_of_inc = bool(set_instead_of_inc)
@@ -2328,6 +2334,9 @@ def copy_of_x(self, x):
23282334
NPY_ARRAY_ENSURECOPY, NULL)"""
23292335

23302336
def c_support_code(self, **kwargs):
2337+
if numpy_version < "1.8.0" or using_numpy_2:
2338+
return None
2339+
23312340
types = [
23322341
"npy_" + t
23332342
for t in [
@@ -2518,15 +2527,104 @@ def gen_num(typen):
25182527
return code
25192528

25202529
def c_code(self, node, name, input_names, output_names, sub):
2521-
if numpy_version < "1.8.0" or using_numpy_2:
2522-
raise NotImplementedError
2523-
25242530
x, y, idx = input_names
2525-
out = output_names[0]
2531+
[out] = output_names
25262532
copy_of_x = self.copy_of_x(x)
25272533
params = sub["params"]
25282534
fail = sub["fail"]
25292535

2536+
x_, y_, idx_ = node.inputs
2537+
y_dtype = y_.type.dtype_specs()[1]
2538+
idx_dtype = idx_.type.dtype_specs()[1]
2539+
out_dtype = node.outputs[0].type.dtype_specs()[1]
2540+
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
2541+
if (
2542+
x_.type.ndim == 1
2543+
and x_.type.dtype not in complex_dtypes
2544+
and not y_bcast
2545+
and y_.type.dtype not in complex_dtypes
2546+
):
2547+
# Simple implementation for vector x, y cases
2548+
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
2549+
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
2550+
shape0 = x_.type.shape[0]
2551+
# This is used to make sure that when we trust the indices to be valid
2552+
# we are not fooled by a wrong static shape
2553+
unexpected_shape0 = (
2554+
f"PyArray_SHAPE({x})[0] != {shape0}" if shape0 is not None else "0"
2555+
)
2556+
2557+
op = "=" if self.set_instead_of_inc else "+="
2558+
code = f"""
2559+
if ({params}->inplace)
2560+
{{
2561+
if ({x} != {out})
2562+
{{
2563+
Py_XDECREF({out});
2564+
Py_INCREF({x});
2565+
{out} = {x};
2566+
}}
2567+
}}
2568+
else
2569+
{{
2570+
Py_XDECREF({out});
2571+
{out} = {copy_of_x};
2572+
if (!{out}) {{
2573+
// Exception already set
2574+
{fail}
2575+
}}
2576+
}}
2577+
2578+
if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
2579+
PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
2580+
{fail}
2581+
}}
2582+
if (PyArray_NDIM({idx}) != 1) {{
2583+
PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim != 1");
2584+
{fail}
2585+
}}
2586+
if ((PyArray_NDIM({y}) != 1) || (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0])) {{
2587+
if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
2588+
PyErr_SetString(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
2589+
}} else {{
2590+
PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match");
2591+
}}
2592+
{fail}
2593+
}}
2594+
2595+
{{
2596+
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
2597+
{out_dtype}* out_data = ({out_dtype}*)PyArray_DATA({out});
2598+
{y_dtype}* y_data = ({y_dtype}*)PyArray_DATA({y});
2599+
{idx_dtype}* idx_data = ({idx_dtype}*)PyArray_DATA({idx});
2600+
npy_intp n = PyArray_SHAPE({idx})[0];
2601+
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
2602+
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
2603+
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
2604+
2605+
for(int i = 0; i < n; i++){{
2606+
{idx_dtype} idx = idx_data[i * idx_jump];
2607+
if ({int(idx_may_be_neg)}){{
2608+
if (idx < 0) {{
2609+
idx += out_shape0;
2610+
}}
2611+
}}
2612+
if ({int(idx_may_be_invalid)}){{
2613+
if ((idx < 0) || (idx >= out_shape0)) {{
2614+
PyErr_Format(PyExc_IndexError,"index out of bounds");
2615+
{fail}
2616+
}}
2617+
}}
2618+
out_data[idx * out_jump] {op} y_data[i * y_jump];
2619+
}}
2620+
2621+
}}
2622+
"""
2623+
return code
2624+
2625+
if numpy_version < "1.8.0" or using_numpy_2:
2626+
raise NotImplementedError
2627+
25302628
return f"""
25312629
PyObject* rval = NULL;
25322630
if ({params}->inplace)
@@ -2554,22 +2652,43 @@ def c_code(self, node, name, input_names, output_names, sub):
25542652
"""
25552653

25562654
def c_code_cache_version(self):
2557-
return (8,)
2655+
return (9,)
2656+
2657+
def _check_runtime_broadcasting(self, node, x, y, idx):
2658+
if y.ndim > 0:
2659+
y_pt_bcast = node.inputs[1].broadcastable
2660+
2661+
if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
2662+
# Attempting to broadcast with index
2663+
raise ValueError(self._runtime_broadcast_error_msg)
2664+
if any(
2665+
not y_bcast and y_dim == 1 and y_dim != x_dim
2666+
for y_bcast, y_dim, x_dim in zip(
2667+
reversed(y_pt_bcast),
2668+
reversed(y.shape),
2669+
reversed(x.shape),
2670+
strict=False,
2671+
)
2672+
):
2673+
# Attempting to broadcast with buffer
2674+
raise ValueError(self._runtime_broadcast_error_msg)
2675+
2676+
def perform(self, node, inputs, output_storage):
2677+
x, y, idx = inputs
25582678

2559-
def perform(self, node, inp, out_):
2560-
x, y, idx = inp
2561-
(out,) = out_
25622679
if not self.inplace:
25632680
x = x.copy()
25642681

2682+
self._check_runtime_broadcasting(node, x, y, idx)
2683+
25652684
if self.set_instead_of_inc:
25662685
x[idx] = y
25672686
else:
25682687
# In Numpy, `x[idx] += y` doesn't work if the same index is present
25692688
# many times: it does it only once.
25702689
np.add.at(x, idx, y)
25712690

2572-
out[0] = x
2691+
output_storage[0][0] = x
25732692

25742693
def infer_shape(self, fgraph, node, ishapes):
25752694
x, y, ilist = ishapes

tests/tensor/test_subtensor.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,8 +1231,6 @@ def test_advanced1_inc_and_set(self):
12311231
data_num_init = np.arange(data_size, dtype=self.dtype)
12321232
data_num_init = data_num_init.reshape(data_shape)
12331233
inc_shapes = [data_shape[i:] for i in range(0, len(data_shape) + 1)]
1234-
# Test broadcasting of y.
1235-
inc_shapes += [(1,) + inc_shapes[-1][1:]]
12361234
for inc_shape in inc_shapes:
12371235
inc_n_dims = len(inc_shape)
12381236
# We copy the numeric value to be 100% sure there is no
@@ -1341,6 +1339,27 @@ def test_advanced1_inc_and_set(self):
13411339
# you enable the debug code above.
13421340
assert np.allclose(f_out, output_num), (params, f_out, output_num)
13431341

1342+
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
1343+
def test_advannced1_inc_runtime_broadcast(self, func):
1344+
y = matrix("y", dtype="float64", shape=(None, None))
1345+
1346+
x = ptb.zeros((10, 5))
1347+
idxs = np.repeat(np.arange(10), 2)
1348+
out = func(x[idxs], y)
1349+
1350+
f = function([y], out)
1351+
f(np.ones((20, 5))) # Fine
1352+
with pytest.raises(
1353+
ValueError,
1354+
match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
1355+
):
1356+
f(np.ones((1, 5)))
1357+
with pytest.raises(
1358+
ValueError,
1359+
match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
1360+
):
1361+
f(np.ones((20, 1)))
1362+
13441363
def test_adv_constant_arg(self):
13451364
# Test case provided (and bug detected, gh-607) by John Salvatier
13461365
m = matrix("m")
@@ -3028,3 +3047,29 @@ def test_advanced_subtensor1(self, static_shape, gc, benchmark):
30283047
)
30293048
fn.vm.allow_gc = gc
30303049
benchmark(fn, x_values, idxs_values)
3050+
3051+
@pytest.mark.parametrize(
3052+
"static_shape", (False, True), ids=lambda x: f"static_shape={x}"
3053+
)
3054+
@pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
3055+
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
3056+
def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
3057+
x = vector("x", shape=(85 if static_shape else None,))
3058+
x_values = np.zeros((85,))
3059+
buffer = ptb.zeros_like(x)
3060+
y_values = np.random.normal(size=(85 * 11,))
3061+
idxs_values = np.arange(85).repeat(11)
3062+
3063+
# With static shape and constant indices we know all idxs are valid
3064+
# Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor
3065+
out1 = func(buffer[idxs_values], y_values)
3066+
out2 = func(buffer[idxs_values[::-1]], y_values)
3067+
3068+
fn = pytensor.function(
3069+
[x],
3070+
[pytensor.Out(out1, borrow=True), pytensor.Out(out2, borrow=True)],
3071+
on_unused_input="ignore",
3072+
trust_input=True,
3073+
)
3074+
fn.vm.allow_gc = gc
3075+
benchmark(fn, x_values)

0 commit comments

Comments
 (0)