Skip to content

Commit 13f5672

Browse files
committed
update unytdtype to work with numpy 2.0
1 parent efe33e2 commit 13f5672

File tree

7 files changed

+41
-43
lines changed

7 files changed

+41
-43
lines changed

unytdtype/unytdtype/src/casts.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#include <Python.h>
22

3+
#include <Python.h>
4+
35
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
6+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
7+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
58
#define NO_IMPORT_ARRAY
69
#include "numpy/arrayobject.h"
7-
#include "numpy/experimental_dtype_api.h"
10+
#include "numpy/dtype_api.h"
811
#include "numpy/ndarraytypes.h"
912

1013
#include "casts.h"
@@ -442,7 +445,7 @@ static PyArray_DTypeMeta *u2u_dtypes[2] = {NULL, NULL};
442445

443446
static PyType_Slot u2u_slots[] = {
444447
{NPY_METH_resolve_descriptors, &unit_to_unit_resolve_descriptors},
445-
{_NPY_METH_get_loop, &unit_to_unit_get_loop},
448+
{NPY_METH_get_loop, &unit_to_unit_get_loop},
446449
{0, NULL}};
447450

448451
static PyArrayMethod_Spec UnitToUnitCastSpec = {
@@ -456,7 +459,7 @@ static PyArrayMethod_Spec UnitToUnitCastSpec = {
456459
};
457460

458461
static PyType_Slot u2f_slots[] = {
459-
{_NPY_METH_get_loop, &unit_to_float64_get_loop}, {0, NULL}};
462+
{NPY_METH_get_loop, &unit_to_float64_get_loop}, {0, NULL}};
460463

461464
static char *u2f_name = "cast_UnytDType_to_Float64";
462465

unytdtype/unytdtype/src/casts.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
#ifndef _NPY_CASTS_H
22
#define _NPY_CASTS_H
33

4-
#include <Python.h>
5-
6-
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
7-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8-
#define NO_IMPORT_ARRAY
9-
#include "numpy/arrayobject.h"
10-
#include "numpy/experimental_dtype_api.h"
11-
#include "numpy/ndarraytypes.h"
12-
134
/* Gets the conversion between two units: */
145
int
156
get_conversion_factor(PyObject *from_unit, PyObject *to_unit, double *factor,

unytdtype/unytdtype/src/dtype.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
// clang-format off
2+
#include <Python.h>
3+
#include "structmember.h"
4+
// clang-format on
5+
6+
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
7+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
8+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
9+
#define NO_IMPORT_ARRAY
10+
#include "numpy/arrayobject.h"
11+
#include "numpy/dtype_api.h"
12+
#include "numpy/ndarraytypes.h"
13+
114
#include "dtype.h"
215

316
#include "casts.h"

unytdtype/unytdtype/src/dtype.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,6 @@
11
#ifndef _NPY_DTYPE_H
22
#define _NPY_DTYPE_H
33

4-
// clang-format off
5-
#include <Python.h>
6-
#include "structmember.h"
7-
// clang-format on
8-
9-
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
10-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
11-
#define NO_IMPORT_ARRAY
12-
#include "numpy/arrayobject.h"
13-
#include "numpy/experimental_dtype_api.h"
14-
#include "numpy/ndarraytypes.h"
15-
164
typedef struct {
175
PyArray_Descr base;
186
PyObject *unit;

unytdtype/unytdtype/src/umath.c

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#include <Python.h>
22

33
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
5+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
56
#define NO_IMPORT_ARRAY
67
#include "numpy/arrayobject.h"
7-
#include "numpy/experimental_dtype_api.h"
8+
#include "numpy/dtype_api.h"
89
#include "numpy/ndarraytypes.h"
910
#include "numpy/ufuncobject.h"
1011

@@ -64,20 +65,22 @@ unit_multiply_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
6465
/*
6566
* Function that adds our multiply loop to NumPy's multiply ufunc.
6667
*/
67-
int
68+
PyObject *
6869
init_multiply_ufunc(void)
6970
{
71+
import_umath();
72+
7073
/*
7174
* Get the multiply ufunc:
7275
*/
7376
PyObject *numpy = PyImport_ImportModule("numpy");
7477
if (numpy == NULL) {
75-
return -1;
78+
return NULL;
7679
}
7780
PyObject *multiply = PyObject_GetAttrString(numpy, "multiply");
78-
Py_DECREF(numpy);
7981
if (multiply == NULL) {
80-
return -1;
82+
Py_DECREF(numpy);
83+
return NULL;
8184
}
8285

8386
/*
@@ -103,8 +106,9 @@ init_multiply_ufunc(void)
103106
/* Register */
104107
if (PyUFunc_AddLoopFromSpec(multiply, &MultiplySpec) < 0) {
105108
Py_DECREF(multiply);
106-
return -1;
109+
Py_DECREF(numpy);
110+
return NULL;
107111
}
108112
Py_DECREF(multiply);
109-
return 0;
113+
return numpy;
110114
}

unytdtype/unytdtype/src/umath.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef _NPY_UFUNC_H
22
#define _NPY_UFUNC_H
33

4-
int
4+
PyObject *
55
init_multiply_ufunc(void);
66

77
#endif /*_NPY_UFUNC_H */

unytdtype/unytdtype/src/unytdtype_main.c

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include <Python.h>
22

33
#define PY_ARRAY_UNIQUE_SYMBOL unytdtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
55
#include "numpy/arrayobject.h"
6-
#include "numpy/experimental_dtype_api.h"
6+
#include "numpy/dtype_api.h"
77

88
#include "dtype.h"
99
#include "umath.h"
@@ -18,12 +18,7 @@ static struct PyModuleDef moduledef = {
1818
PyMODINIT_FUNC
1919
PyInit__unytdtype_main(void)
2020
{
21-
if (_import_array() < 0) {
22-
return NULL;
23-
}
24-
if (import_experimental_dtype_api(15) < 0) {
25-
return NULL;
26-
}
21+
import_array();
2722

2823
PyObject *m = PyModule_Create(&moduledef);
2924
if (m == NULL) {
@@ -50,10 +45,14 @@ PyInit__unytdtype_main(void)
5045
goto error;
5146
}
5247

53-
if (init_multiply_ufunc() < 0) {
48+
PyObject *numpy = init_multiply_ufunc();
49+
50+
if (numpy == NULL) {
5451
goto error;
5552
}
5653

54+
Py_DECREF(numpy);
55+
5756
return m;
5857

5958
error:

0 commit comments

Comments
 (0)