Skip to content

Commit a1d07eb

Browse files
ArmavicaricardoV94
authored andcommitted
Remove npy<2 compatibility for NPY_RAVEL_AXIS value
1 parent 05f6985 commit a1d07eb

File tree

4 files changed

+13
-17
lines changed

4 files changed

+13
-17
lines changed

pytensor/npy_2_compat.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,6 @@
1616
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
1717

1818

19-
# to patch up some of the C code, we need to use these special values...
20-
if using_numpy_2:
21-
numpy_axis_is_none_flag = np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS"
22-
else:
23-
# 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0
24-
numpy_axis_is_none_flag = 32
25-
26-
2719
# function that replicates np.unique from numpy < 2.0
2820
def old_np_unique(
2921
arr, return_index=False, return_inverse=False, return_counts=False, axis=None

pytensor/tensor/extra_ops.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@
1818
from pytensor.link.c.op import COp
1919
from pytensor.link.c.params_type import ParamsType
2020
from pytensor.link.c.type import EnumList, Generic
21-
from pytensor.npy_2_compat import (
22-
npy_2_compat_header,
23-
numpy_axis_is_none_flag,
24-
old_np_unique,
25-
)
21+
from pytensor.npy_2_compat import npy_2_compat_header, old_np_unique
2622
from pytensor.raise_op import Assert
2723
from pytensor.scalar import int64 as int_t
2824
from pytensor.scalar import upcast
@@ -51,7 +47,7 @@
5147
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
5248
from pytensor.tensor.utils import normalize_reduce_axis
5349
from pytensor.tensor.variable import TensorVariable
54-
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
50+
from pytensor.utils import LOCAL_BITWIDTH, NPY_RAVEL_AXIS, PYTHON_INT_BITWIDTH
5551

5652

5753
class CpuContiguous(COp):
@@ -308,7 +304,7 @@ def __init__(self, axis: int | None = None, mode="add"):
308304
@property
309305
def c_axis(self) -> int:
310306
if self.axis is None:
311-
return numpy_axis_is_none_flag
307+
return NPY_RAVEL_AXIS
312308
return self.axis
313309

314310
def make_node(self, x):

pytensor/tensor/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.graph.replace import _vectorize_node
1515
from pytensor.link.c.op import COp
1616
from pytensor.link.c.params_type import ParamsType
17-
from pytensor.npy_2_compat import npy_2_compat_header, numpy_axis_is_none_flag
17+
from pytensor.npy_2_compat import npy_2_compat_header
1818
from pytensor.printing import pprint
1919
from pytensor.raise_op import Assert
2020
from pytensor.scalar.basic import BinaryScalarOp
@@ -162,7 +162,7 @@ def get_params(self, node):
162162
c_axis = np.int64(self.axis[0])
163163
else:
164164
# The value here doesn't matter, it won't be used
165-
c_axis = numpy_axis_is_none_flag
165+
c_axis = 0
166166
return self.params_type.get_params(c_axis=c_axis)
167167

168168
def make_node(self, x):

pytensor/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from functools import partial
1111
from pathlib import Path
1212

13+
import numpy as np
14+
1315

1416
__all__ = [
1517
"get_unbound_function",
@@ -19,6 +21,7 @@
1921
"output_subprocess_Popen",
2022
"LOCAL_BITWIDTH",
2123
"PYTHON_INT_BITWIDTH",
24+
"NPY_RAVEL_AXIS",
2225
"NoDuplicateOptWarningFilter",
2326
]
2427

@@ -46,6 +49,11 @@
4649
'l' denotes a C long int, and the size is expressed in bytes.
4750
"""
4851

52+
NPY_RAVEL_AXIS = np.iinfo(np.int32).min
53+
"""
54+
The value of the numpy C API NPY_RAVEL_AXIS.
55+
"""
56+
4957

5058
def __call_excepthooks(type, value, trace):
5159
"""

0 commit comments

Comments
 (0)