Skip to content

Commit 6e84408

Browse files
committed
Let numpy methods handle integer size problems in AdvancedSubtensor1
1 parent 0b56ed9 commit 6e84408

File tree

1 file changed

+7
-82
lines changed

1 file changed

+7
-82
lines changed

pytensor/tensor/subtensor.py

Lines changed: 7 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44
from collections.abc import Callable, Iterable, Sequence
55
from itertools import chain, groupby
6-
from textwrap import dedent
76
from typing import cast, overload
87

98
import numpy as np
@@ -19,7 +18,7 @@
1918
from pytensor.graph.utils import MethodNotDefined
2019
from pytensor.link.c.op import COp
2120
from pytensor.link.c.params_type import ParamsType
22-
from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2
21+
from pytensor.npy_2_compat import numpy_version, using_numpy_2
2322
from pytensor.printing import Printer, pprint, set_precedence
2423
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
2524
from pytensor.tensor import (
@@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_):
21302129
else:
21312130
o = None
21322131

2133-
# If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
2134-
# int64 on 64-bit machines), numpy may raise the following error:
2135-
# TypeError: array cannot be safely cast to required type.
2136-
# We need to check if values in i can fit in numpy.intp, because
2137-
# if they don't, that should be an error (no array can have that
2138-
# many elements on a 32-bit arch).
2139-
if i.dtype != np.intp:
2140-
i_ = np.asarray(i, dtype=np.intp)
2141-
if not np.can_cast(i.dtype, np.intp):
2142-
# Check if there was actually an incorrect conversion
2143-
if np.any(i != i_):
2144-
raise IndexError(
2145-
"index contains values that are bigger "
2146-
"than the maximum array size on this system.",
2147-
i,
2148-
)
2149-
i = i_
2150-
21512132
out[0] = x.take(i, axis=0, out=o)
21522133

21532134
def connection_pattern(self, node):
@@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes):
21872168
x, ilist = ishapes
21882169
return [ilist + x[1:]]
21892170

2190-
def c_support_code(self, **kwargs):
2191-
# In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
2192-
# which is not defined. It should be NPY_MIN_LONG instead in that case.
2193-
return npy_2_compat_header() + dedent(
2194-
"""\
2195-
#ifndef MIN_LONG
2196-
#define MIN_LONG NPY_MIN_LONG
2197-
#endif"""
2198-
)
2199-
22002171
def c_code(self, node, name, input_names, output_names, sub):
22012172
if self.__class__ is not AdvancedSubtensor1:
22022173
raise MethodNotDefined(
@@ -2207,69 +2178,24 @@ def c_code(self, node, name, input_names, output_names, sub):
22072178
output_name = output_names[0]
22082179
fail = sub["fail"]
22092180
return f"""
2210-
PyArrayObject *indices;
2211-
int i_type = PyArray_TYPE({i_name});
2212-
if (i_type != NPY_INTP) {{
2213-
// Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom),
2214-
// if all values fit.
2215-
if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
2216-
PyArray_SIZE({i_name}) > 0) {{
2217-
npy_int64 min_val, max_val;
2218-
PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
2219-
NULL);
2220-
if (py_min_val == NULL) {{
2221-
{fail};
2222-
}}
2223-
min_val = PyLong_AsLongLong(py_min_val);
2224-
Py_DECREF(py_min_val);
2225-
if (min_val == -1 && PyErr_Occurred()) {{
2226-
{fail};
2227-
}}
2228-
PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
2229-
NULL);
2230-
if (py_max_val == NULL) {{
2231-
{fail};
2232-
}}
2233-
max_val = PyLong_AsLongLong(py_max_val);
2234-
Py_DECREF(py_max_val);
2235-
if (max_val == -1 && PyErr_Occurred()) {{
2236-
{fail};
2237-
}}
2238-
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
2239-
PyErr_SetString(PyExc_IndexError,
2240-
"Index contains values "
2241-
"that are bigger than the maximum array "
2242-
"size on this system.");
2243-
{fail};
2244-
}}
2245-
}}
2246-
indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
2247-
if (indices == NULL) {{
2248-
{fail};
2249-
}}
2250-
}}
2251-
else {{
2252-
indices = {i_name};
2253-
Py_INCREF(indices);
2254-
}}
22552181
if ({output_name} != NULL) {{
22562182
npy_intp nd, i, *shape;
2257-
nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1;
2183+
nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
22582184
if (PyArray_NDIM({output_name}) != nd) {{
22592185
Py_CLEAR({output_name});
22602186
}}
22612187
else {{
22622188
shape = PyArray_DIMS({output_name});
2263-
for (i = 0; i < PyArray_NDIM(indices); i++) {{
2264-
if (shape[i] != PyArray_DIMS(indices)[i]) {{
2189+
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
2190+
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
22652191
Py_CLEAR({output_name});
22662192
break;
22672193
}}
22682194
}}
22692195
if ({output_name} != NULL) {{
22702196
for (; i < nd; i++) {{
22712197
if (shape[i] != PyArray_DIMS({a_name})[
2272-
i-PyArray_NDIM(indices)+1]) {{
2198+
i-PyArray_NDIM({i_name})+1]) {{
22732199
Py_CLEAR({output_name});
22742200
break;
22752201
}}
@@ -2278,13 +2204,12 @@ def c_code(self, node, name, input_names, output_names, sub):
22782204
}}
22792205
}}
22802206
{output_name} = (PyArrayObject*)PyArray_TakeFrom(
2281-
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
2282-
Py_DECREF(indices);
2207+
{a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
22832208
if ({output_name} == NULL) {fail};
22842209
"""
22852210

22862211
def c_code_cache_version(self):
2287-
return (0, 1, 2, 3)
2212+
return 1
22882213

22892214

22902215
advanced_subtensor1 = AdvancedSubtensor1()

0 commit comments

Comments
 (0)