Skip to content

Commit f15a953

Browse files
ArmavicaricardoV94
authored andcommitted
Remove npy<2 compatibility for ndarray_c_version
1 parent a1d07eb commit f15a953

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

pytensor/link/c/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from pytensor.link.c.cmodule import get_module_cache as _get_module_cache
3030
from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType
3131
from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline
32-
from pytensor.npy_2_compat import ndarray_c_version
33-
from pytensor.utils import difference, uniq
32+
from pytensor.utils import NDARRAY_C_VERSION, difference, uniq
3433

3534

3635
NoParams = object()
@@ -1367,7 +1366,7 @@ def cmodule_key_(
13671366

13681367
# We must always add the numpy ABI version here as
13691368
# DynamicModule always add the include <numpy/arrayobject.h>
1370-
sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}")
1369+
sig.append(f"NPY_ABI_VERSION=0x{NDARRAY_C_VERSION:X}")
13711370
if c_compiler:
13721371
sig.append("c_compiler_str=" + c_compiler.version_str())
13731372

pytensor/npy_2_compat.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
using_numpy_2 = numpy_version >= "2.0.0rc1"
1111

1212

13-
if using_numpy_2:
14-
ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
15-
else:
16-
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
1713

1814

1915
# function that replicates np.unique from numpy < 2.0

pytensor/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"LOCAL_BITWIDTH",
2323
"PYTHON_INT_BITWIDTH",
2424
"NPY_RAVEL_AXIS",
25+
"NDARRAY_C_VERSION",
2526
"NoDuplicateOptWarningFilter",
2627
]
2728

@@ -54,6 +55,8 @@
5455
The value of the numpy C API NPY_RAVEL_AXIS.
5556
"""
5657

58+
NDARRAY_C_VERSION = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
59+
5760

5861
def __call_excepthooks(type, value, trace):
5962
"""

0 commit comments

Comments
 (0)