Skip to content

Commit bca9c0c

Browse files
jorenhamcrusaderky
andcommitted
TYP: apply review suggestions
Co-authored-by: crusaderky <[email protected]>
1 parent a522dbc commit bca9c0c

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

array_api_compat/_internal.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22
Internal helpers
33
"""
44

5+
from collections.abc import Callable
56
from functools import wraps
67
from inspect import signature
7-
from typing import TYPE_CHECKING
8+
from types import ModuleType
9+
from typing import TypeVar
810

911
__all__ = ["get_xp"]
1012

11-
if TYPE_CHECKING:
12-
from collections.abc import Callable
13-
from types import ModuleType
14-
from typing import TypeVar
13+
_T = TypeVar("_T")
1514

16-
_T = TypeVar("_T")
1715

18-
19-
def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]":
16+
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
2017
"""
2118
Decorator to automatically replace xp with the corresponding array module.
2219
@@ -33,7 +30,7 @@ def func(x, /, xp, kwarg=None):
3330
3431
"""
3532

36-
def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]":
33+
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
3734
@wraps(f)
3835
def wrapped_f(*args: object, **kwargs: object) -> object:
3936
return f(*args, xp=xp, **kwargs)

array_api_compat/common/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ._typing import Array, Device, DType, Namespace
1414

1515
if TYPE_CHECKING:
16+
# TODO: import from typing (requires Python >=3.13)
1617
from typing_extensions import TypeIs
1718

1819
# These functions are modified from the NumPy versions.

array_api_compat/common/_helpers.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,22 @@
1212
import math
1313
import sys
1414
import warnings
15-
from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload
15+
from collections.abc import Collection
16+
from typing import (
17+
TYPE_CHECKING,
18+
Any,
19+
Literal,
20+
SupportsIndex,
21+
TypeAlias,
22+
TypeGuard,
23+
TypeVar,
24+
cast,
25+
overload,
26+
)
1627

1728
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
1829

1930
if TYPE_CHECKING:
20-
from collections.abc import Collection
2131

2232
import dask.array as da
2333
import jax
@@ -26,7 +36,9 @@
2636
import numpy.typing as npt
2737
import sparse # pyright: ignore[reportMissingTypeStubs]
2838
import torch
29-
from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar
39+
40+
# TODO: import from typing (requires Python >=3.13)
41+
from typing_extensions import TypeIs, TypeVar
3042

3143
_SizeT = TypeVar("_SizeT", bound=int | None)
3244

array_api_compat/numpy/_aliases.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from builtins import bool as py_bool
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
66

77
import numpy as np
88

@@ -13,11 +13,9 @@
1313
from ._typing import Array, Device, DType
1414

1515
if TYPE_CHECKING:
16-
from typing import Any, Literal
16+
from typing_extensions import Buffer, TypeIs
1717

18-
from typing_extensions import Buffer, TypeAlias, TypeIs
19-
20-
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
18+
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
2119

2220
bool = np.bool_
2321

0 commit comments

Comments
 (0)