Skip to content

Commit fe75549

Browse files
committed
Merge branch 'main' into cache_helpers
2 parents fc6b56b + 205c967 commit fe75549

File tree

17 files changed

+1048
-467
lines changed

17 files changed

+1048
-467
lines changed

array_api_compat/_internal.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
Internal helpers
33
"""
44

5+
from collections.abc import Callable
56
from functools import wraps
67
from inspect import signature
8+
from types import ModuleType
9+
from typing import TypeVar
710

8-
def get_xp(xp):
11+
_T = TypeVar("_T")
12+
13+
14+
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
915
"""
1016
Decorator to automatically replace xp with the corresponding array module.
1117
@@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):
2228
2329
"""
2430

25-
def inner(f):
31+
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
2632
@wraps(f)
27-
def wrapped_f(*args, **kwargs):
33+
def wrapped_f(*args: object, **kwargs: object) -> object:
2834
return f(*args, xp=xp, **kwargs)
2935

3036
sig = signature(f)
3137
new_sig = sig.replace(
32-
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
38+
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
3339
)
3440

3541
if wrapped_f.__doc__ is None:
@@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
4046
specification for more details.
4147
4248
"""
43-
wrapped_f.__signature__ = new_sig
44-
return wrapped_f
49+
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50+
return wrapped_f # pyright: ignore[reportReturnType]
4551

4652
return inner
53+
54+
55+
__all__ = ["get_xp"]
56+
57+
58+
def __dir__() -> list[str]:
59+
return __all__

array_api_compat/common/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._helpers import * # noqa: F403
1+
from ._helpers import * # noqa: F403

0 commit comments

Comments
 (0)