Skip to content

Commit 0a756c0

Browse files
mtsokolpre-commit-ci[bot]rwgk
authored
MAINT: Include numpy._core imports (pybind#4857)
* MAINT: Include numpy._core imports * style: pre-commit fixes * Apply review comments * style: pre-commit fixes * Add no-inline attribute * Select submodule name based on numpy version * style: pre-commit fixes * Update pre-commit check * Add error_already_set and simplify if statement * Update .pre-commit-config.yaml Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent f468b07 commit 0a756c0

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ repos:
142142
- id: disallow-caps
143143
name: Disallow improper capitalization
144144
language: pygrep
145-
entry: PyBind|Numpy|Cmake|CCache|PyTest
145+
entry: PyBind|\bNumpy\b|Cmake|CCache|PyTest
146146
exclude: ^\.pre-commit-config.yaml$
147147

148148
# PyLint has native support - not always usable, but works for us

include/pybind11/numpy.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ inline numpy_internals &get_numpy_internals() {
120120
return *ptr;
121121
}
122122

123+
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
124+
module_ numpy = module_::import("numpy");
125+
str version_string = numpy.attr("__version__");
126+
127+
module_ numpy_lib = module_::import("numpy.lib");
128+
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
129+
int major_version = numpy_version.attr("major").cast<int>();
130+
131+
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
132+
became a private module. */
133+
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
134+
return module_::import((numpy_core_path + "." + submodule_name).c_str());
135+
}
136+
123137
template <typename T>
124138
struct same_size {
125139
template <typename U>
@@ -263,9 +277,13 @@ struct npy_api {
263277
};
264278

265279
static npy_api lookup() {
266-
module_ m = module_::import("numpy.core.multiarray");
280+
module_ m = detail::import_numpy_core_submodule("multiarray");
267281
auto c = m.attr("_ARRAY_API");
268282
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
283+
if (api_ptr == nullptr) {
284+
raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
285+
throw error_already_set();
286+
}
269287
npy_api api;
270288
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
271289
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
@@ -626,11 +644,8 @@ class dtype : public object {
626644

627645
private:
628646
static object _dtype_from_pep3118() {
629-
static PyObject *obj = module_::import("numpy.core._internal")
630-
.attr("_dtype_from_pep3118")
631-
.cast<object>()
632-
.release()
633-
.ptr();
647+
module_ m = detail::import_numpy_core_submodule("_internal");
648+
static PyObject *obj = m.attr("_dtype_from_pep3118").cast<object>().release().ptr();
634649
return reinterpret_borrow<object>(obj);
635650
}
636651

0 commit comments

Comments
 (0)