Skip to content

Commit 1bc3f00

Browse files
committed
update quaddtype to work on numpy 2.0
1 parent 13f5672 commit 1bc3f00

File tree

7 files changed

+46
-35
lines changed

7 files changed

+46
-35
lines changed

quaddtype/quaddtype/src/casts.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1-
#include "casts.h"
1+
#include <Python.h>
2+
3+
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
5+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
6+
#define NO_IMPORT_ARRAY
7+
#include "numpy/ndarraytypes.h"
8+
#include "numpy/arrayobject.h"
9+
#include "numpy/dtype_api.h"
10+
211
#include "dtype.h"
12+
#include "casts.h"
313

414
// And now the actual cast code! Starting with the "resolver" which tells
515
// us about cast safety.
@@ -131,7 +141,7 @@ static PyArray_DTypeMeta *QuadToQuadDtypes[2] = {NULL, NULL};
131141

132142
static PyType_Slot QuadToQuadSlots[] = {
133143
{NPY_METH_resolve_descriptors, &quad_to_quad_resolve_descriptors},
134-
{_NPY_METH_get_loop, &quad_to_quad_get_loop},
144+
{NPY_METH_get_loop, &quad_to_quad_get_loop},
135145
{0, NULL}};
136146

137147
PyArrayMethod_Spec QuadToQuadCastSpec = {

quaddtype/quaddtype/src/casts.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef _NPY_CASTS_H
22
#define _NPY_CASTS_H
33

4-
#include "numpy/experimental_dtype_api.h"
54
extern PyArrayMethod_Spec QuadToQuadCastSpec;
65
extern PyArrayMethod_Spec QuadToFloat128CastSpec;
76

quaddtype/quaddtype/src/dtype.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
#include <Python.h>
2+
3+
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
5+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
6+
#define NO_IMPORT_ARRAY
7+
#include "numpy/ndarraytypes.h"
8+
#include "numpy/arrayobject.h"
9+
#include "numpy/dtype_api.h"
10+
111
#include "dtype.h"
212
#include "abstract.h"
313
#include "casts.h"

quaddtype/quaddtype/src/dtype.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
#ifndef _NPY_DTYPE_H
22
#define _NPY_DTYPE_H
33

4-
#include <Python.h>
5-
6-
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
7-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8-
#define NO_IMPORT_ARRAY
9-
#include "numpy/ndarraytypes.h"
10-
#include "numpy/arrayobject.h"
11-
#include "numpy/experimental_dtype_api.h"
12-
134
typedef struct {
145
PyArray_Descr base;
156
} QuadDTypeObject;

quaddtype/quaddtype/src/quaddtype_main.c

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include <Python.h>
22

33
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
55
#include "numpy/arrayobject.h"
6-
#include "numpy/experimental_dtype_api.h"
6+
#include "numpy/dtype_api.h"
77

88
#include "dtype.h"
99
#include "umath.h"
@@ -19,16 +19,9 @@ static struct PyModuleDef moduledef = {
1919
PyMODINIT_FUNC
2020
PyInit__quaddtype_main(void)
2121
{
22-
if (_import_array() < 0)
22+
if (import_array() < 0)
2323
return NULL;
2424

25-
// Fail to init if the experimental DType API version 5 isn't supported
26-
if (import_experimental_dtype_api(15) < 0) {
27-
PyErr_SetString(PyExc_ImportError,
28-
"Error encountered importing the experimental dtype API.");
29-
return NULL;
30-
}
31-
3225
PyObject *m = PyModule_Create(&moduledef);
3326
if (m == NULL) {
3427
PyErr_SetString(PyExc_ImportError, "Unable to create the quaddtype_main module.");
@@ -57,11 +50,15 @@ PyInit__quaddtype_main(void)
5750
goto error;
5851
}
5952

60-
if (init_multiply_ufunc() < 0) {
53+
PyObject *numpy = init_multiply_ufunc();
54+
55+
if (numpy == NULL) {
6156
PyErr_SetString(PyExc_TypeError, "Failed to initialize the quadscalar multiply ufunc.");
6257
goto error;
6358
}
6459

60+
Py_DECREF(numpy);
61+
6562
return m;
6663

6764
error:

quaddtype/quaddtype/src/umath.c

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#include <Python.h>
22

33
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
5+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
56
#define NO_IMPORT_ARRAY
6-
#include "numpy/arrayobject.h"
77
#include "numpy/ndarraytypes.h"
8+
#include "numpy/arrayobject.h"
89
#include "numpy/ufuncobject.h"
9-
10-
#include "numpy/experimental_dtype_api.h"
10+
#include "numpy/dtype_api.h"
1111

1212
#include "dtype.h"
1313
#include "umath.h"
@@ -73,20 +73,22 @@ quad_multiply_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
7373
}
7474

7575
// Function that adds our multiply loop to NumPy's multiply ufunc.
76-
int
76+
PyObject*
7777
init_multiply_ufunc(void)
7878
{
79+
import_umath();
80+
7981
// Get the multiply ufunc:
8082
PyObject *numpy = PyImport_ImportModule("numpy");
8183
if (numpy == NULL) {
82-
return -1;
84+
return NULL;
8385
}
86+
8487
PyObject *multiply = PyObject_GetAttrString(numpy, "multiply");
8588

86-
// Why decref here?
87-
Py_DECREF(numpy);
8889
if (multiply == NULL) {
89-
return -1;
90+
Py_DECREF(numpy);
91+
return NULL;
9092
}
9193

9294
// The initializing "wrap up" code from the slides (plus one error check)
@@ -114,8 +116,10 @@ init_multiply_ufunc(void)
114116
/* Register */
115117
if (PyUFunc_AddLoopFromSpec(multiply, &MultiplySpec) < 0) {
116118
Py_DECREF(multiply);
117-
return -1;
119+
Py_DECREF(numpy);
120+
return NULL;
118121
}
122+
119123
Py_DECREF(multiply);
120-
return 0;
124+
return numpy;
121125
}

quaddtype/quaddtype/src/umath.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef _NPY_UFUNC_H
22
#define _NPY_UFUNC_H
33

4-
int
4+
PyObject *
55
init_multiply_ufunc(void);
66

77
#endif /*_NPY_UFUNC_H */

0 commit comments

Comments
 (0)