Skip to content

Commit 8a7f9b3

Browse files
committed
[API-Compat] Make the forbid_keywords decorator transparent
1 parent 97f1d5b commit 8a7f9b3

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

python/paddle/utils/compat_kwarg_check.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, Callable, TypeVar
17+
import functools
18+
import inspect
19+
from typing import Any, Callable, TypeVar, cast
1820

1921
F = TypeVar('F', bound=Callable[..., Any])
2022

@@ -29,10 +31,12 @@ def forbid_keywords(
2931
illegal_keys: list[str] | str - Forbidden keyword names
3032
correct_func_name: str - Recommended function name
3133
"""
32-
if isinstance(illegal_keys, str):
33-
illegal_keys = [illegal_keys]
34+
keys = [illegal_keys] if isinstance(illegal_keys, str) else illegal_keys
3435

3536
def decorator(func: F) -> F:
37+
orig_sig = inspect.signature(func)
38+
39+
@functools.wraps(func)
3640
def wrapper(*args: Any, **kwargs: Any) -> Any:
3741
found_keys = [key for key in illegal_keys if key in kwargs]
3842

@@ -47,6 +51,14 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
4751

4852
return func(*args, **kwargs)
4953

50-
return wrapper
54+
# Important: function signatures / specs should be copied to avoid erroneous input/output extraction (particularly in static graph, like test_split_op.py)
55+
wrapper.__signature__ = orig_sig
56+
if hasattr(func, "__defaults__"):
57+
wrapper.__defaults__ = func.__defaults__
58+
if hasattr(func, "__kwdefaults__"):
59+
wrapper.__kwdefaults__ = func.__kwdefaults__
60+
wrapper.__wrapped__ = func
61+
62+
return cast('F', wrapper)
5163

5264
return decorator

0 commit comments

Comments
 (0)