Skip to content

Commit 0fb036c

Browse files
Changes to support NumPy 2.0
Use np.lib.NumpyVersion to account for change in namespace of some functions. Replaced import from numpy.core with import from numpy main namespace, where appropriate. Replace np.longcomplex with np.clongdouble. Fixed import of NumPy C-API in _pydfti extension. This follows https://numpy.org/devdocs/reference/c-api/array.html#including-and-importing-the-c-api Separate translation units, like mklfft.c, need to include arrayobject.h after defining NO_IMPORT_ARRAY preprocessor variable. The extension must be compiled with PY_ARRAY_UNIQUE_SYMBOL set to artibrary value, but same for all translation units.
1 parent 276b142 commit 0fb036c

File tree

7 files changed

+24
-20
lines changed

7 files changed

+24
-20
lines changed

mkl_fft/_float_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __downcast_float128_array(x):
5959
xdt = x.dtype
6060
if xdt == np.longdouble and not xdt == np.float64:
6161
return np.asarray(x, dtype=np.float64)
62-
elif xdt == np.longcomplex and not xdt == np.complex_:
62+
elif xdt == np.clongdouble and not xdt == np.complex_:
6363
return np.asarray(x, dtype=np.complex_)
6464
if not isinstance(x, np.ndarray):
6565
__x = np.asarray(x)

mkl_fft/_numpy_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
__all__ = ['fft', 'ifft', 'rfft', 'irfft', 'hfft', 'ihfft', 'rfftn',
5757
'irfftn', 'rfft2', 'irfft2', 'fft2', 'ifft2', 'fftn', 'ifftn']
5858

59-
from numpy.core import (array, asarray, asanyarray, shape, conjugate, take, sqrt, prod)
59+
from numpy import (array, asarray, asanyarray, shape, conjugate, take, sqrt, prod)
6060

6161
import numpy
6262
from . import _pydfti as mkl_fft

mkl_fft/_pydfti.pyx

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
# imports
3030
import sys
3131
import numpy as np
32-
from numpy.core._multiarray_tests import internal_overlap
32+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0a0":
33+
from numpy._core._multiarray_tests import internal_overlap
34+
else:
35+
from numpy.core._multiarray_tests import internal_overlap
3336
from threading import local as threading_local
3437

3538
# cimports
@@ -133,11 +136,6 @@ cdef extern from "src/mklfft.h":
133136
int double_cdouble_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray, double)
134137
char * mkl_dfti_error(int)
135138

136-
# Initialize numpy
137-
cdef int numpy_import_status = cnp.import_array()
138-
if numpy_import_status < 0:
139-
raise ImportError("Failed to import NumPy as dependency of mkl_fft")
140-
141139

142140
cdef int _datacopied(cnp.ndarray arr, object orig):
143141
"""
@@ -217,7 +215,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis,
217215
cnp.NPY_ELEMENTSTRIDES | cnp.NPY_ENSUREARRAY | cnp.NPY_NOTSWAPPED,
218216
NULL)
219217

220-
if <void *> x_arr is NULL:
218+
if (<void *> x_arr) is NULL:
221219
raise ValueError("An input argument x is not an array-like object")
222220

223221
if _datacopied(x_arr, x):

mkl_fft/_scipy_fft_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from . import _float_utils
2929
import mkl
3030

31-
from numpy.core import (take, sqrt, prod)
31+
from numpy import (take, sqrt, prod)
3232
import contextvars
3333
import contextlib
3434
import operator

mkl_fft/src/mklfft.c.src

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2017-2020, Intel Corporation
2+
Copyright (c) 2017-2024, Intel Corporation
33

44
Redistribution and use in source and binary forms, with or without
55
modification, are permitted provided that the following conditions are met:
@@ -25,9 +25,9 @@
2525
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*/
2727

28-
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
29-
28+
#define PY_SSIZE_T_CLEAN
3029
#include "Python.h"
30+
#define NO_IMPORT_ARRAY
3131
#include "numpy/arrayobject.h"
3232
#include "mklfft.h"
3333
#include "multi_iter.h"
@@ -81,13 +81,17 @@ static NPY_INLINE void get_basic_array_data(
8181
npy_intp *x_size)
8282
{
8383
npy_intp asize = 0;
84+
npy_intp elsz = 0;
85+
int x_ndim = 0;
8486
assert(x != NULL);
8587

86-
*x_rank = PyArray_NDIM(x);
88+
x_ndim = PyArray_NDIM(x);
89+
*x_rank = x_ndim;
8790
*x_shape = PyArray_SHAPE(x);
8891
*x_strides = PyArray_STRIDES(x);
89-
*x_itemsize = PyArray_ITEMSIZE(x);
90-
asize = ar_size(*x_shape, *x_rank);
92+
elsz = PyArray_ITEMSIZE(x);
93+
*x_itemsize = elsz;
94+
asize = ar_size(*x_shape, x_ndim);
9195
*x_size = asize;
9296
}
9397

mkl_fft/src/mklfft.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*/
2727
#include "mkl.h"
28+
#include "numpy/arrayobject.h"
2829

2930
typedef struct DftiCache {
3031
DFTI_DESCRIPTOR_HANDLE hand;

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2017-2023, Intel Corporation
2+
# Copyright (c) 2017-2024, Intel Corporation
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions are met:
@@ -49,10 +49,10 @@
4949
Programming Language :: C
5050
Programming Language :: Python
5151
Programming Language :: Python :: 3
52-
Programming Language :: Python :: 3.7
53-
Programming Language :: Python :: 3.8
5452
Programming Language :: Python :: 3.9
5553
Programming Language :: Python :: 3.10
54+
Programming Language :: Python :: 3.11
55+
Programming Language :: Python :: 3.12
5656
Programming Language :: Python :: Implementation :: CPython
5757
Topic :: Software Development
5858
Topic :: Scientific/Engineering
@@ -104,7 +104,8 @@ def extensions():
104104
extra_compile_args = [
105105
'-DNDEBUG',
106106
# '-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
107-
]
107+
],
108+
define_macros=[("NPY_NO_DEPRECATED_API", None), ("PY_ARRAY_UNIQUE_SYMBOL", "mkl_fft_ext")]
108109
)
109110
]
110111

0 commit comments

Comments
 (0)