Skip to content

Commit 38bf02c

Browse files
committed
Remove npy<2 compatibility for normalize_axis_{index,tuple}
1 parent 615cfcb commit 38bf02c

File tree

13 files changed

+14
-32
lines changed

13 files changed

+14
-32
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numba
55
import numpy as np
66
from numba.core.extending import overload
7+
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
78
from numpy.lib.stride_tricks import as_strided
89

910
from pytensor.graph.op import Op
@@ -19,7 +20,6 @@
1920
store_core_outputs,
2021
)
2122
from pytensor.link.utils import compile_function_src
22-
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
2323
from pytensor.scalar.basic import (
2424
AND,
2525
OR,

pytensor/npy_2_compat.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@
33
import numpy as np
44

55

6-
# Conditional numpy imports for numpy 1.26 and 2.x compatibility
7-
try:
8-
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
9-
except ModuleNotFoundError:
10-
# numpy < 2.0
11-
from numpy.core.multiarray import normalize_axis_index # type: ignore[no-redef]
12-
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
13-
14-
156
try:
167
from numpy._core.einsumfunc import ( # type: ignore[attr-defined]
178
_find_contraction,
@@ -28,8 +19,6 @@
2819
__all__ = [
2920
"_find_contraction",
3021
"_parse_einsum_input",
31-
"normalize_axis_index",
32-
"normalize_axis_tuple",
3322
]
3423

3524

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
from numpy.exceptions import AxisError
18+
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
1819

1920
import pytensor
2021
import pytensor.scalar.sharedvar
@@ -31,7 +32,6 @@
3132
from pytensor.graph.type import HasShape, Type
3233
from pytensor.link.c.op import COp
3334
from pytensor.link.c.params_type import ParamsType
34-
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
3636
from pytensor.raise_op import CheckAndRaise
3737
from pytensor.scalar import int32

pytensor/tensor/blas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@
8383
from pathlib import Path
8484

8585
import numpy as np
86+
from numpy.lib.array_utils import normalize_axis_tuple
8687
from scipy.linalg import get_blas_funcs
8788

8889
from pytensor.graph import Variable, vectorize_graph
89-
from pytensor.npy_2_compat import normalize_axis_tuple
9090

9191

9292
try:

pytensor/tensor/einsum.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,10 @@
66
from typing import cast
77

88
import numpy as np
9+
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
910

1011
from pytensor.compile.builders import OpFromGraph
11-
from pytensor.npy_2_compat import (
12-
_find_contraction,
13-
_parse_einsum_input,
14-
normalize_axis_index,
15-
normalize_axis_tuple,
16-
)
12+
from pytensor.npy_2_compat import _find_contraction, _parse_einsum_input
1713
from pytensor.tensor import TensorLike
1814
from pytensor.tensor.basic import (
1915
arange,

pytensor/tensor/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Literal
55

66
import numpy as np
7+
from numpy.lib.array_utils import normalize_axis_tuple
78

89
import pytensor.tensor.basic
910
from pytensor.configdefaults import config
@@ -16,7 +17,6 @@
1617
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
1718
from pytensor.link.c.params_type import ParamsType
1819
from pytensor.misc.frozendict import frozendict
19-
from pytensor.npy_2_compat import normalize_axis_tuple
2020
from pytensor.printing import Printer, pprint
2121
from pytensor.scalar import get_scalar_type
2222
from pytensor.scalar.basic import identity as scalar_identity

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Collection, Iterable
33

44
import numpy as np
5+
from numpy.lib.array_utils import normalize_axis_index
56

67
import pytensor
78
import pytensor.scalar.basic as ps
@@ -18,7 +19,6 @@
1819
from pytensor.link.c.params_type import ParamsType
1920
from pytensor.link.c.type import EnumList, Generic
2021
from pytensor.npy_2_compat import (
21-
normalize_axis_index,
2222
npy_2_compat_header,
2323
numpy_axis_is_none_flag,
2424
old_np_unique,

pytensor/tensor/math.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Optional
66

77
import numpy as np
8+
from numpy.lib.array_utils import normalize_axis_tuple
89

910
from pytensor import config, printing
1011
from pytensor import scalar as ps
@@ -13,11 +14,7 @@
1314
from pytensor.graph.replace import _vectorize_node
1415
from pytensor.link.c.op import COp
1516
from pytensor.link.c.params_type import ParamsType
16-
from pytensor.npy_2_compat import (
17-
normalize_axis_tuple,
18-
npy_2_compat_header,
19-
numpy_axis_is_none_flag,
20-
)
17+
from pytensor.npy_2_compat import npy_2_compat_header, numpy_axis_is_none_flag
2118
from pytensor.printing import pprint
2219
from pytensor.raise_op import Assert
2320
from pytensor.scalar.basic import BinaryScalarOp

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from typing import Literal, cast
55

66
import numpy as np
7+
from numpy.lib.array_utils import normalize_axis_tuple
78

89
from pytensor import scalar as ps
910
from pytensor.compile.builders import OpFromGraph
1011
from pytensor.gradient import DisconnectedType
1112
from pytensor.graph.basic import Apply
1213
from pytensor.graph.op import Op
13-
from pytensor.npy_2_compat import normalize_axis_tuple
1414
from pytensor.tensor import TensorLike
1515
from pytensor.tensor import basic as ptb
1616
from pytensor.tensor import math as ptm

pytensor/tensor/rewriting/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import logging
2626

2727
import numpy as np
28+
from numpy.lib.array_utils import normalize_axis_index
2829

2930
from pytensor import compile, config
3031
from pytensor.compile.ops import ViewOp
@@ -41,7 +42,6 @@
4142
)
4243
from pytensor.graph.rewriting.db import RewriteDatabase
4344
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
44-
from pytensor.npy_2_compat import normalize_axis_index
4545
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
4646
from pytensor.scalar import (
4747
AND,

0 commit comments

Comments
 (0)