Skip to content

Commit 634d13e

Browse files
Replace use of np.MAXDIMS
`np.MAXDIMS` was removed from the public API and no replacement is given in the migration docs. In numpy <= 1.26, the value of `np.MAXDIMS` was 32. This was often used as a flag to mean `axis=None`. In numpy >= 2.0, the maximum number of dims of an array has been increased to 64; simultaneously, a constant `NPY_RAVEL_AXIS` was added to the C-API to indicate that `axis=None`. In most cases, the use of `np.MAXDIMS` to check for `axis=None` can be replaced by the new constant `NPY_RAVEL_AXIS`. To make this constant accessible when using numpy <= 1.26, I added a function to insert `npy_2_compat.h` into the support code for the affected ops.
1 parent 4e8c52d commit 634d13e

File tree

7 files changed

+319
-39
lines changed

7 files changed

+319
-39
lines changed

pytensor/npy_2_compat.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
from textwrap import dedent
2+
3+
4+
def npy_2_compat_header() -> str:
5+
return dedent("""
6+
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
7+
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
8+
9+
10+
/*
11+
* This header is meant to be included by downstream directly for 1.x compat.
12+
* In that case we need to ensure that users first included the full headers
13+
* and not just `ndarraytypes.h`.
14+
*/
15+
16+
#ifndef NPY_FEATURE_VERSION
17+
#error "The NumPy 2 compat header requires `import_array()` for which " \\
18+
"the `ndarraytypes.h` header include is not sufficient. Please " \\
19+
"include it after `numpy/ndarrayobject.h` or similar." \\
20+
"" \\
21+
"To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\
22+
"which is defined in the compat header and is lightweight (can be)."
23+
#endif
24+
25+
#if NPY_ABI_VERSION < 0x02000000
26+
/*
27+
* Define 2.0 feature version as it is needed below to decide whether we
28+
* compile for both 1.x and 2.x (defining it gaurantees 1.x only).
29+
*/
30+
#define NPY_2_0_API_VERSION 0x00000012
31+
/*
32+
* If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we
33+
* pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`.
34+
* This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to.
35+
*/
36+
#define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION
37+
/* Compiling on NumPy 1.x where these are the same: */
38+
#define PyArray_DescrProto PyArray_Descr
39+
#endif
40+
41+
42+
/*
43+
* Define a better way to call `_import_array()` to simplify backporting as
44+
* we now require imports more often (necessary to make ABI flexible).
45+
*/
46+
#ifdef import_array1
47+
48+
static inline int
49+
PyArray_ImportNumPyAPI()
50+
{
51+
if (NPY_UNLIKELY(PyArray_API == NULL)) {
52+
import_array1(-1);
53+
}
54+
return 0;
55+
}
56+
57+
#endif /* import_array1 */
58+
59+
60+
/*
61+
* NPY_DEFAULT_INT
62+
*
63+
* The default integer has changed, `NPY_DEFAULT_INT` is available at runtime
64+
* for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`.
65+
*
66+
* NPY_RAVEL_AXIS
67+
*
68+
* This was introduced in NumPy 2.0 to allow indicating that an axis should be
69+
* raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose.
70+
*
71+
* NPY_MAXDIMS
72+
*
73+
* A constant indicating the maximum number dimensions allowed when creating
74+
* an ndarray.
75+
*
76+
* NPY_NTYPES_LEGACY
77+
*
78+
* The number of built-in NumPy dtypes.
79+
*/
80+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
81+
#define NPY_DEFAULT_INT NPY_INTP
82+
#define NPY_RAVEL_AXIS NPY_MIN_INT
83+
#define NPY_MAXARGS 64
84+
85+
#elif NPY_ABI_VERSION < 0x02000000
86+
#define NPY_DEFAULT_INT NPY_LONG
87+
#define NPY_RAVEL_AXIS 32
88+
#define NPY_MAXARGS 32
89+
90+
/* Aliases of 2.x names to 1.x only equivalent names */
91+
#define NPY_NTYPES NPY_NTYPES_LEGACY
92+
#define PyArray_DescrProto PyArray_Descr
93+
#define _PyArray_LegacyDescr PyArray_Descr
94+
/* NumPy 2 definition always works, but add it for 1.x only */
95+
#define PyDataType_ISLEGACY(dtype) (1)
96+
#else
97+
#define NPY_DEFAULT_INT \\
98+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG)
99+
#define NPY_RAVEL_AXIS \\
100+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32)
101+
#define NPY_MAXARGS \\
102+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32)
103+
#endif
104+
105+
106+
/*
107+
* Access inline functions for descriptor fields. Except for the first
108+
* few fields, these needed to be moved (elsize, alignment) for
109+
* additional space. Or they are descriptor specific and are not generally
110+
* available anymore (metadata, c_metadata, subarray, names, fields).
111+
*
112+
* Most of these are defined via the `DESCR_ACCESSOR` macro helper.
113+
*/
114+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000
115+
/* Compiling for 1.x or 2.x only, direct field access is OK: */
116+
117+
static inline void
118+
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
119+
{
120+
dtype->elsize = size;
121+
}
122+
123+
static inline npy_uint64
124+
PyDataType_FLAGS(const PyArray_Descr *dtype)
125+
{
126+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
127+
return dtype->flags;
128+
#else
129+
return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */
130+
#endif
131+
}
132+
133+
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
134+
static inline type \\
135+
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
136+
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
137+
return (type)0; \\
138+
} \\
139+
return ((_PyArray_LegacyDescr *)dtype)->field; \\
140+
}
141+
#else /* compiling for both 1.x and 2.x */
142+
143+
static inline void
144+
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
145+
{
146+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
147+
((_PyArray_DescrNumPy2 *)dtype)->elsize = size;
148+
}
149+
else {
150+
((PyArray_DescrProto *)dtype)->elsize = (int)size;
151+
}
152+
}
153+
154+
static inline npy_uint64
155+
PyDataType_FLAGS(const PyArray_Descr *dtype)
156+
{
157+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
158+
return ((_PyArray_DescrNumPy2 *)dtype)->flags;
159+
}
160+
else {
161+
return (unsigned char)((PyArray_DescrProto *)dtype)->flags;
162+
}
163+
}
164+
165+
/* Cast to LegacyDescr always fine but needed when `legacy_only` */
166+
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
167+
static inline type \\
168+
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
169+
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
170+
return (type)0; \\
171+
} \\
172+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\
173+
return ((_PyArray_LegacyDescr *)dtype)->field; \\
174+
} \\
175+
else { \\
176+
return ((PyArray_DescrProto *)dtype)->field; \\
177+
} \\
178+
}
179+
#endif
180+
181+
DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0)
182+
DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0)
183+
DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1)
184+
DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1)
185+
DESCR_ACCESSOR(NAMES, names, PyObject *, 1)
186+
DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1)
187+
DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1)
188+
189+
#undef DESCR_ACCESSOR
190+
191+
192+
#if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD)
193+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
194+
static inline PyArray_ArrFuncs *
195+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
196+
{
197+
return _PyDataType_GetArrFuncs(descr);
198+
}
199+
#elif NPY_ABI_VERSION < 0x02000000
200+
static inline PyArray_ArrFuncs *
201+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
202+
{
203+
return descr->f;
204+
}
205+
#else
206+
static inline PyArray_ArrFuncs *
207+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
208+
{
209+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
210+
return _PyDataType_GetArrFuncs(descr);
211+
}
212+
else {
213+
return ((PyArray_DescrProto *)descr)->f;
214+
}
215+
}
216+
#endif
217+
218+
219+
#endif /* not internal build */
220+
221+
#endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */
222+
223+
""")

pytensor/tensor/elemwise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# numpy < 2.0
1313
from numpy.core.numeric import normalize_axis_tuple
1414

15+
1516
import pytensor.tensor.basic
1617
from pytensor.configdefaults import config
1718
from pytensor.gradient import DisconnectedType

pytensor/tensor/extra_ops.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from pytensor.link.c.op import COp
2626
from pytensor.link.c.params_type import ParamsType
2727
from pytensor.link.c.type import EnumList, Generic
28+
from pytensor.npy_2_compat import npy_2_compat_header
2829
from pytensor.raise_op import Assert
29-
from pytensor.scalar import int32 as int_t
30+
from pytensor.scalar import int64 as int_t
3031
from pytensor.scalar import upcast
3132
from pytensor.tensor import TensorLike, as_tensor_variable
3233
from pytensor.tensor import basic as ptb
@@ -306,7 +307,14 @@ def __init__(self, axis: int | None = None, mode="add"):
306307
self.axis = axis
307308
self.mode = mode
308309

309-
c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
310+
@property
311+
def c_axis(self) -> int:
312+
if self.axis is None:
313+
if np.__version__ < "2":
314+
return 32 # value used to mark axis = None in Numpy C-API prior to version 2.0
315+
else:
316+
return np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS"
317+
return self.axis
310318

311319
def make_node(self, x):
312320
x = ptb.as_tensor_variable(x)
@@ -363,24 +371,38 @@ def infer_shape(self, fgraph, node, shapes):
363371

364372
return shapes
365373

374+
def c_support_code_apply(self, node: Apply, name: str) -> str:
375+
"""Needed to define NPY_RAVEL_AXIS"""
376+
return npy_2_compat_header()
377+
366378
def c_code(self, node, name, inames, onames, sub):
367379
(x,) = inames
368380
(z,) = onames
369381
fail = sub["fail"]
370382
params = sub["params"]
371383

372-
code = f"""
373-
int axis = {params}->c_axis;
384+
if self.axis is None:
385+
axis_code = "int axis = NPY_RAVEL_AXIS;\n"
386+
else:
387+
axis_code = "int axis = {params}->c_axis;\n"
388+
389+
code = (
390+
axis_code
391+
+ """
392+
#undef NPY_UF_DBG_TRACING
393+
#define NPY_UF_DBG_TRACING 1
394+
374395
if (axis == 0 && PyArray_NDIM({x}) == 1)
375-
axis = NPY_MAXDIMS;
396+
axis = NPY_RAVEL_AXIS;
376397
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
377-
if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
398+
if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
378399
{{
379400
Py_XDECREF({z});
380-
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}));
401+
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
402+
//{z} = (PyArrayObject*) PyArray_NewLikeArray((PyArrayObject*) PyArray_Ravel({x}, NPY_ANYORDER), NPY_ANYORDER, NULL, 0);
381403
}}
382404
383-
else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
405+
else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
384406
{{
385407
Py_XDECREF({z});
386408
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
@@ -407,11 +429,12 @@ def c_code(self, node, name, inames, onames, sub):
407429
Py_XDECREF(t);
408430
}}
409431
"""
432+
).format(**locals())
410433

411434
return code
412435

413436
def c_code_cache_version(self):
414-
return (8,)
437+
return (9,)
415438

416439
def __str__(self):
417440
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"

pytensor/tensor/math.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pytensor.graph.replace import _vectorize_node
2121
from pytensor.link.c.op import COp
2222
from pytensor.link.c.params_type import ParamsType
23+
from pytensor.npy_2_compat import npy_2_compat_header
2324
from pytensor.printing import pprint
2425
from pytensor.raise_op import Assert
2526
from pytensor.scalar.basic import BinaryScalarOp
@@ -166,7 +167,10 @@ def get_params(self, node):
166167
c_axis = np.int64(self.axis[0])
167168
else:
168169
# The value here doesn't matter, it won't be used
169-
c_axis = np.int64(-1)
170+
if np.__version__ < "2":
171+
c_axis = np.int64(-1)
172+
else:
173+
c_axis = -2147483648 # the value of "NPY_RAVEL_AXIS"
170174
return self.params_type.get_params(c_axis=c_axis)
171175

172176
def make_node(self, x):
@@ -209,13 +213,17 @@ def perform(self, node, inp, outs):
209213

210214
max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
211215

216+
def c_support_code_apply(self, node: Apply, name: str) -> str:
217+
"""Needed to define NPY_RAVEL_AXIS"""
218+
return npy_2_compat_header()
219+
212220
def c_code(self, node, name, inp, out, sub):
213221
(x,) = inp
214222
(argmax,) = out
215223
fail = sub["fail"]
216224
params = sub["params"]
217225
if self.axis is None:
218-
axis_code = "axis = NPY_MAXDIMS;"
226+
axis_code = "axis = NPY_RAVEL_AXIS;"
219227
else:
220228
if len(self.axis) != 1:
221229
raise NotImplementedError()

0 commit comments

Comments
 (0)