2
2
Internal helpers
3
3
"""
4
4
5
+ from collections .abc import Callable
5
6
from functools import wraps
6
7
from inspect import signature
8
+ from types import ModuleType
9
+ from typing import TypeVar
7
10
8
- def get_xp (xp ):
11
+ _T = TypeVar ("_T" )
12
+
13
+
14
+ def get_xp (xp : ModuleType ) -> Callable [[Callable [..., _T ]], Callable [..., _T ]]:
9
15
"""
10
16
Decorator to automatically replace xp with the corresponding array module.
11
17
@@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):
22
28
23
29
"""
24
30
25
- def inner (f ) :
31
+ def inner (f : Callable [..., _T ], / ) -> Callable [..., _T ] :
26
32
@wraps (f )
27
- def wrapped_f (* args , ** kwargs ) :
33
+ def wrapped_f (* args : object , ** kwargs : object ) -> object :
28
34
return f (* args , xp = xp , ** kwargs )
29
35
30
36
sig = signature (f )
31
37
new_sig = sig .replace (
32
- parameters = [sig . parameters [ i ] for i in sig .parameters if i != "xp" ]
38
+ parameters = [par for i , par in sig .parameters . items () if i != "xp" ]
33
39
)
34
40
35
41
if wrapped_f .__doc__ is None :
@@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
40
46
specification for more details.
41
47
42
48
"""
43
- wrapped_f .__signature__ = new_sig
44
- return wrapped_f
49
+ wrapped_f .__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50
+ return wrapped_f # pyright: ignore[reportReturnType]
45
51
46
52
return inner
53
+
54
+
55
+ __all__ = ["get_xp" ]
56
+
57
+
58
+ def __dir__ () -> list [str ]:
59
+ return __all__
0 commit comments