|
19 | 19 |
|
20 | 20 | from lightning_utilities.core.imports import RequirementCache
|
21 | 21 | from torch import nn
|
22 |
| -from typing_extensions import Concatenate, ParamSpec |
| 22 | +from typing_extensions import Concatenate, ParamSpec, override |
23 | 23 |
|
24 | 24 | import lightning.pytorch as pl
|
25 | 25 |
|
@@ -104,23 +104,30 @@ def _check_mixed_imports(instance: object) -> None:
|
104 | 104 | _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method
|
105 | 105 |
|
106 | 106 |
|
107 |
| -def _restricted_classmethod_impl(method: Callable[Concatenate[type[_T], _P], _R_co]) -> classmethod: |
| 107 | +class _restricted_classmethod_impl(classmethod): |
108 | 108 | """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
|
109 | 109 | instead of a class type."""
|
110 | 110 |
|
111 |
| - # The wrapper ensures that the method can be inspected, but not called on an instance |
112 |
| - @functools.wraps(method) |
113 |
| - def wrapper(cls: type[_T], *args: _P.args, **kwargs: _P.kwargs) -> _R_co: |
114 |
| - # Workaround for https://github.com/pytorch/pytorch/issues/67146 |
115 |
| - is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) |
116 |
| - if inspect.isclass(cls) and not is_scripting: |
117 |
| - raise TypeError( |
118 |
| - f"The classmethod `{cls.__name__}.{method.__name__}` cannot be called on an instance." |
119 |
| - " Please call it on the class type and make sure the return value is used." |
120 |
| - ) |
121 |
| - return method(cls, *args, **kwargs) |
122 |
| - |
123 |
| - return classmethod(wrapper) |
| 111 | + def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: |
| 112 | + super().__init__(method) |
| 113 | + self.method = method |
| 114 | + |
| 115 | + @override |
| 116 | + def __get__(self, instance: Optional[_T], cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: |
| 117 | + # The wrapper ensures that the method can be inspected, but not called on an instance |
| 118 | + @functools.wraps(self.method) |
| 119 | + def wrapper(*args: Any, **kwargs: Any) -> _R_co: |
| 120 | + # Workaround for https://github.com/pytorch/pytorch/issues/67146 |
| 121 | + is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) |
| 122 | + cls_type = cls if cls is not None else type(instance) |
| 123 | + if instance is not None and not is_scripting: |
| 124 | + raise TypeError( |
| 125 | + f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance." |
| 126 | + " Please call it on the class type and make sure the return value is used." |
| 127 | + ) |
| 128 | + return self.method(cls_type, *args, **kwargs) |
| 129 | + |
| 130 | + return wrapper |
124 | 131 |
|
125 | 132 |
|
126 | 133 | if TYPE_CHECKING:
|
|
0 commit comments