Skip to content

Commit ad135cf

Browse files
ArmavicaricardoV94
authored andcommitted
Remove the numpy < 2 compatibility header
1 parent 31cc397 commit ad135cf

File tree

4 files changed

+6
-254
lines changed

4 files changed

+6
-254
lines changed

pytensor/npy_2_compat.py

Lines changed: 0 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from textwrap import dedent
2-
31
import numpy as np
42

53

@@ -22,227 +20,3 @@ def old_np_unique(
2220
outs[inv_idx] = outs[inv_idx].reshape(inv_shape)
2321

2422
return tuple(outs)
25-
26-
27-
# compatibility header for C code
28-
def npy_2_compat_header() -> str:
29-
"""Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x"""
30-
return dedent("""
31-
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
32-
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
33-
34-
35-
/*
36-
* This header is meant to be included by downstream directly for 1.x compat.
37-
* In that case we need to ensure that users first included the full headers
38-
* and not just `ndarraytypes.h`.
39-
*/
40-
41-
#ifndef NPY_FEATURE_VERSION
42-
#error "The NumPy 2 compat header requires `import_array()` for which " \\
43-
"the `ndarraytypes.h` header include is not sufficient. Please " \\
44-
"include it after `numpy/ndarrayobject.h` or similar." \\
45-
"" \\
46-
"To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\
47-
"which is defined in the compat header and is lightweight (can be)."
48-
#endif
49-
50-
#if NPY_ABI_VERSION < 0x02000000
51-
/*
52-
* Define 2.0 feature version as it is needed below to decide whether we
53-
* compile for both 1.x and 2.x (defining it gaurantees 1.x only).
54-
*/
55-
#define NPY_2_0_API_VERSION 0x00000012
56-
/*
57-
* If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we
58-
* pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`.
59-
* This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to.
60-
*/
61-
#define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION
62-
/* Compiling on NumPy 1.x where these are the same: */
63-
#define PyArray_DescrProto PyArray_Descr
64-
#endif
65-
66-
67-
/*
68-
* Define a better way to call `_import_array()` to simplify backporting as
69-
* we now require imports more often (necessary to make ABI flexible).
70-
*/
71-
#ifdef import_array1
72-
73-
static inline int
74-
PyArray_ImportNumPyAPI()
75-
{
76-
if (NPY_UNLIKELY(PyArray_API == NULL)) {
77-
import_array1(-1);
78-
}
79-
return 0;
80-
}
81-
82-
#endif /* import_array1 */
83-
84-
85-
/*
86-
* NPY_DEFAULT_INT
87-
*
88-
* The default integer has changed, `NPY_DEFAULT_INT` is available at runtime
89-
* for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`.
90-
*
91-
* NPY_RAVEL_AXIS
92-
*
93-
* This was introduced in NumPy 2.0 to allow indicating that an axis should be
94-
* raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose.
95-
*
96-
* NPY_MAXDIMS
97-
*
98-
* A constant indicating the maximum number dimensions allowed when creating
99-
* an ndarray.
100-
*
101-
* NPY_NTYPES_LEGACY
102-
*
103-
* The number of built-in NumPy dtypes.
104-
*/
105-
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
106-
#define NPY_DEFAULT_INT NPY_INTP
107-
#define NPY_RAVEL_AXIS NPY_MIN_INT
108-
#define NPY_MAXARGS 64
109-
110-
#elif NPY_ABI_VERSION < 0x02000000
111-
#define NPY_DEFAULT_INT NPY_LONG
112-
#define NPY_RAVEL_AXIS 32
113-
#define NPY_MAXARGS 32
114-
115-
/* Aliases of 2.x names to 1.x only equivalent names */
116-
#define NPY_NTYPES NPY_NTYPES_LEGACY
117-
#define PyArray_DescrProto PyArray_Descr
118-
#define _PyArray_LegacyDescr PyArray_Descr
119-
/* NumPy 2 definition always works, but add it for 1.x only */
120-
#define PyDataType_ISLEGACY(dtype) (1)
121-
#else
122-
#define NPY_DEFAULT_INT \\
123-
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG)
124-
#define NPY_RAVEL_AXIS \\
125-
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32)
126-
#define NPY_MAXARGS \\
127-
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32)
128-
#endif
129-
130-
131-
/*
132-
* Access inline functions for descriptor fields. Except for the first
133-
* few fields, these needed to be moved (elsize, alignment) for
134-
* additional space. Or they are descriptor specific and are not generally
135-
* available anymore (metadata, c_metadata, subarray, names, fields).
136-
*
137-
* Most of these are defined via the `DESCR_ACCESSOR` macro helper.
138-
*/
139-
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000
140-
/* Compiling for 1.x or 2.x only, direct field access is OK: */
141-
142-
static inline void
143-
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
144-
{
145-
dtype->elsize = size;
146-
}
147-
148-
static inline npy_uint64
149-
PyDataType_FLAGS(const PyArray_Descr *dtype)
150-
{
151-
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
152-
return dtype->flags;
153-
#else
154-
return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */
155-
#endif
156-
}
157-
158-
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
159-
static inline type \\
160-
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
161-
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
162-
return (type)0; \\
163-
} \\
164-
return ((_PyArray_LegacyDescr *)dtype)->field; \\
165-
}
166-
#else /* compiling for both 1.x and 2.x */
167-
168-
static inline void
169-
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
170-
{
171-
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
172-
((_PyArray_DescrNumPy2 *)dtype)->elsize = size;
173-
}
174-
else {
175-
((PyArray_DescrProto *)dtype)->elsize = (int)size;
176-
}
177-
}
178-
179-
static inline npy_uint64
180-
PyDataType_FLAGS(const PyArray_Descr *dtype)
181-
{
182-
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
183-
return ((_PyArray_DescrNumPy2 *)dtype)->flags;
184-
}
185-
else {
186-
return (unsigned char)((PyArray_DescrProto *)dtype)->flags;
187-
}
188-
}
189-
190-
/* Cast to LegacyDescr always fine but needed when `legacy_only` */
191-
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
192-
static inline type \\
193-
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
194-
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
195-
return (type)0; \\
196-
} \\
197-
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\
198-
return ((_PyArray_LegacyDescr *)dtype)->field; \\
199-
} \\
200-
else { \\
201-
return ((PyArray_DescrProto *)dtype)->field; \\
202-
} \\
203-
}
204-
#endif
205-
206-
DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0)
207-
DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0)
208-
DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1)
209-
DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1)
210-
DESCR_ACCESSOR(NAMES, names, PyObject *, 1)
211-
DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1)
212-
DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1)
213-
214-
#undef DESCR_ACCESSOR
215-
216-
217-
#if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD)
218-
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
219-
static inline PyArray_ArrFuncs *
220-
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
221-
{
222-
return _PyDataType_GetArrFuncs(descr);
223-
}
224-
#elif NPY_ABI_VERSION < 0x02000000
225-
static inline PyArray_ArrFuncs *
226-
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
227-
{
228-
return descr->f;
229-
}
230-
#else
231-
static inline PyArray_ArrFuncs *
232-
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
233-
{
234-
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
235-
return _PyDataType_GetArrFuncs(descr);
236-
}
237-
else {
238-
return ((PyArray_DescrProto *)descr)->f;
239-
}
240-
}
241-
#endif
242-
243-
244-
#endif /* not internal build */
245-
246-
#endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */
247-
248-
""")

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pytensor.link.c.op import COp
1919
from pytensor.link.c.params_type import ParamsType
2020
from pytensor.link.c.type import EnumList, Generic
21-
from pytensor.npy_2_compat import npy_2_compat_header, old_np_unique
21+
from pytensor.npy_2_compat import old_np_unique
2222
from pytensor.raise_op import Assert
2323
from pytensor.scalar import int64 as int_t
2424
from pytensor.scalar import upcast
@@ -362,10 +362,6 @@ def infer_shape(self, fgraph, node, shapes):
362362

363363
return shapes
364364

365-
def c_support_code_apply(self, node: Apply, name: str) -> str:
366-
"""Needed to define NPY_RAVEL_AXIS"""
367-
return npy_2_compat_header()
368-
369365
def c_code(self, node, name, inames, onames, sub):
370366
(x,) = inames
371367
(z,) = onames
@@ -424,7 +420,7 @@ def c_code(self, node, name, inames, onames, sub):
424420
return code
425421

426422
def c_code_cache_version(self):
427-
return (9,)
423+
return (10,)
428424

429425
def __str__(self):
430426
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"

pytensor/tensor/math.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pytensor.graph.replace import _vectorize_node
1515
from pytensor.link.c.op import COp
1616
from pytensor.link.c.params_type import ParamsType
17-
from pytensor.npy_2_compat import npy_2_compat_header
1817
from pytensor.printing import pprint
1918
from pytensor.raise_op import Assert
2019
from pytensor.scalar.basic import BinaryScalarOp
@@ -205,10 +204,6 @@ def perform(self, node, inp, outs):
205204

206205
max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
207206

208-
def c_support_code_apply(self, node: Apply, name: str) -> str:
209-
"""Needed to define NPY_RAVEL_AXIS"""
210-
return npy_2_compat_header()
211-
212207
def c_code(self, node, name, inp, out, sub):
213208
(x,) = inp
214209
(argmax,) = out
@@ -255,7 +250,7 @@ def c_code(self, node, name, inp, out, sub):
255250
"""
256251

257252
def c_code_cache_version(self):
258-
return (2,)
253+
return (3,)
259254

260255
def infer_shape(self, fgraph, node, shapes):
261256
(ishape,) = shapes

pytensor/tensor/special.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pytensor.graph.basic import Apply
77
from pytensor.graph.replace import _vectorize_node
88
from pytensor.link.c.op import COp
9-
from pytensor.npy_2_compat import npy_2_compat_header
109
from pytensor.tensor.basic import as_tensor_variable
1110
from pytensor.tensor.elemwise import get_normalized_batch_axes
1211
from pytensor.tensor.math import gamma, gammaln, log, neg, sum
@@ -61,11 +60,7 @@ def infer_shape(self, fgraph, node, shape):
6160
return [shape[1]]
6261

6362
def c_code_cache_version(self):
64-
return (5,)
65-
66-
def c_support_code_apply(self, node: Apply, name: str) -> str:
67-
# return super().c_support_code_apply(node, name)
68-
return npy_2_compat_header()
63+
return (6,)
6964

7065
def c_code(self, node, name, inp, out, sub):
7166
dy, sm = inp
@@ -296,10 +291,6 @@ def infer_shape(self, fgraph, node, shape):
296291
def c_headers(self, **kwargs):
297292
return ["<cmath>"]
298293

299-
def c_support_code_apply(self, node: Apply, name: str) -> str:
300-
"""Needed to define NPY_RAVEL_AXIS"""
301-
return npy_2_compat_header()
302-
303294
def c_code(self, node, name, inp, out, sub):
304295
(x,) = inp
305296
(sm,) = out
@@ -495,7 +486,7 @@ def c_code(self, node, name, inp, out, sub):
495486

496487
@staticmethod
497488
def c_code_cache_version():
498-
return (5,)
489+
return (6,)
499490

500491

501492
def softmax(c, axis=None):
@@ -555,10 +546,6 @@ def infer_shape(self, fgraph, node, shape):
555546
def c_headers(self, **kwargs):
556547
return ["<cmath>"]
557548

558-
def c_support_code_apply(self, node: Apply, name: str) -> str:
559-
"""Needed to define NPY_RAVEL_AXIS"""
560-
return npy_2_compat_header()
561-
562549
def c_code(self, node, name, inp, out, sub):
563550
(x,) = inp
564551
(sm,) = out
@@ -750,7 +737,7 @@ def c_code(self, node, name, inp, out, sub):
750737

751738
@staticmethod
752739
def c_code_cache_version():
753-
return (2,)
740+
return (3,)
754741

755742

756743
def log_softmax(c, axis=None):

0 commit comments

Comments
 (0)