Skip to content

Commit 9188aa6

Browse files
authored
Backwards compatibility with NumPy v1 (#835)
1 parent 1c0172c commit 9188aa6

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/valor_lite/classification/computation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
from numpy.typing import NDArray
33

4+
import valor_lite.classification.numpy_compatibility as npc
5+
46

57
def _compute_rocauc(
68
data: NDArray[np.float64],
@@ -56,7 +58,7 @@ def _compute_rocauc(
5658
np.maximum.accumulate(tpr, axis=1, out=tpr)
5759

5860
# compute rocauc
59-
rocauc = np.trapezoid(x=fpr, y=tpr, axis=1)
61+
rocauc = npc.trapezoid(x=fpr, y=tpr, axis=1)
6062

6163
# compute mean rocauc
6264
mean_rocauc = rocauc.mean()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
4+
try:
5+
_numpy_trapezoid = np.trapezoid # numpy v2
6+
except AttributeError:
7+
_numpy_trapezoid = np.trapz # numpy v1
8+
9+
10+
def trapezoid(
11+
x: NDArray[np.float64], y: NDArray[np.float64], axis: int
12+
) -> NDArray[np.float64]:
13+
return _numpy_trapezoid(x=x, y=y, axis=axis) # type: ignore - NumPy compatibility

0 commit comments

Comments
 (0)