33import warnings
44from collections .abc import Callable , Iterable , Sequence
55from itertools import chain , groupby
6- from textwrap import dedent
76from typing import cast , overload
87
98import numpy as np
1918from pytensor .graph .utils import MethodNotDefined
2019from pytensor .link .c .op import COp
2120from pytensor .link .c .params_type import ParamsType
22- from pytensor .npy_2_compat import npy_2_compat_header , numpy_version , using_numpy_2
21+ from pytensor .npy_2_compat import numpy_version , using_numpy_2
2322from pytensor .printing import Printer , pprint , set_precedence
2423from pytensor .scalar .basic import ScalarConstant , ScalarVariable
2524from pytensor .tensor import (
@@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_):
21302129 else :
21312130 o = None
21322131
2133- # If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
2134- # int64 on 64-bit machines), numpy may raise the following error:
2135- # TypeError: array cannot be safely cast to required type.
2136- # We need to check if values in i can fit in numpy.intp, because
2137- # if they don't, that should be an error (no array can have that
2138- # many elements on a 32-bit arch).
2139- if i .dtype != np .intp :
2140- i_ = np .asarray (i , dtype = np .intp )
2141- if not np .can_cast (i .dtype , np .intp ):
2142- # Check if there was actually an incorrect conversion
2143- if np .any (i != i_ ):
2144- raise IndexError (
2145- "index contains values that are bigger "
2146- "than the maximum array size on this system." ,
2147- i ,
2148- )
2149- i = i_
2150-
21512132 out [0 ] = x .take (i , axis = 0 , out = o )
21522133
21532134 def connection_pattern (self , node ):
@@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes):
21872168 x , ilist = ishapes
21882169 return [ilist + x [1 :]]
21892170
2190- def c_support_code (self , ** kwargs ):
2191- # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
2192- # which is not defined. It should be NPY_MIN_LONG instead in that case.
2193- return npy_2_compat_header () + dedent (
2194- """\
2195- #ifndef MIN_LONG
2196- #define MIN_LONG NPY_MIN_LONG
2197- #endif"""
2198- )
2199-
22002171 def c_code (self , node , name , input_names , output_names , sub ):
22012172 if self .__class__ is not AdvancedSubtensor1 :
22022173 raise MethodNotDefined (
@@ -2207,69 +2178,24 @@ def c_code(self, node, name, input_names, output_names, sub):
22072178 output_name = output_names [0 ]
22082179 fail = sub ["fail" ]
22092180 return f"""
2210- PyArrayObject *indices;
2211- int i_type = PyArray_TYPE({ i_name } );
2212- if (i_type != NPY_INTP) {{
2213- // Cast { i_name } to NPY_INTP (expected by PyArray_TakeFrom),
2214- // if all values fit.
2215- if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
2216- PyArray_SIZE({ i_name } ) > 0) {{
2217- npy_int64 min_val, max_val;
2218- PyObject* py_min_val = PyArray_Min({ i_name } , NPY_RAVEL_AXIS,
2219- NULL);
2220- if (py_min_val == NULL) {{
2221- { fail } ;
2222- }}
2223- min_val = PyLong_AsLongLong(py_min_val);
2224- Py_DECREF(py_min_val);
2225- if (min_val == -1 && PyErr_Occurred()) {{
2226- { fail } ;
2227- }}
2228- PyObject* py_max_val = PyArray_Max({ i_name } , NPY_RAVEL_AXIS,
2229- NULL);
2230- if (py_max_val == NULL) {{
2231- { fail } ;
2232- }}
2233- max_val = PyLong_AsLongLong(py_max_val);
2234- Py_DECREF(py_max_val);
2235- if (max_val == -1 && PyErr_Occurred()) {{
2236- { fail } ;
2237- }}
2238- if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
2239- PyErr_SetString(PyExc_IndexError,
2240- "Index contains values "
2241- "that are bigger than the maximum array "
2242- "size on this system.");
2243- { fail } ;
2244- }}
2245- }}
2246- indices = (PyArrayObject*) PyArray_Cast({ i_name } , NPY_INTP);
2247- if (indices == NULL) {{
2248- { fail } ;
2249- }}
2250- }}
2251- else {{
2252- indices = { i_name } ;
2253- Py_INCREF(indices);
2254- }}
22552181 if ({ output_name } != NULL) {{
22562182 npy_intp nd, i, *shape;
2257- nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM(indices ) - 1;
2183+ nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM({ i_name } ) - 1;
22582184 if (PyArray_NDIM({ output_name } ) != nd) {{
22592185 Py_CLEAR({ output_name } );
22602186 }}
22612187 else {{
22622188 shape = PyArray_DIMS({ output_name } );
2263- for (i = 0; i < PyArray_NDIM(indices ); i++) {{
2264- if (shape[i] != PyArray_DIMS(indices )[i]) {{
2189+ for (i = 0; i < PyArray_NDIM({ i_name } ); i++) {{
2190+ if (shape[i] != PyArray_DIMS({ i_name } )[i]) {{
22652191 Py_CLEAR({ output_name } );
22662192 break;
22672193 }}
22682194 }}
22692195 if ({ output_name } != NULL) {{
22702196 for (; i < nd; i++) {{
22712197 if (shape[i] != PyArray_DIMS({ a_name } )[
2272- i-PyArray_NDIM(indices )+1]) {{
2198+ i-PyArray_NDIM({ i_name } )+1]) {{
22732199 Py_CLEAR({ output_name } );
22742200 break;
22752201 }}
@@ -2278,13 +2204,12 @@ def c_code(self, node, name, input_names, output_names, sub):
22782204 }}
22792205 }}
22802206 { output_name } = (PyArrayObject*)PyArray_TakeFrom(
2281- { a_name } , (PyObject*)indices, 0, { output_name } , NPY_RAISE);
2282- Py_DECREF(indices);
2207+ { a_name } , (PyObject*){ i_name } , 0, { output_name } , NPY_RAISE);
22832208 if ({ output_name } == NULL) { fail } ;
22842209 """
22852210
22862211 def c_code_cache_version (self ):
2287- return ( 0 , 1 , 2 , 3 )
2212+ return 1
22882213
22892214
22902215advanced_subtensor1 = AdvancedSubtensor1 ()
0 commit comments