|
19 | 19 |
|
20 | 20 | from lightning_utilities.core.imports import RequirementCache |
21 | 21 | from torch import nn |
22 | | -from typing_extensions import Concatenate, ParamSpec |
| 22 | +from typing_extensions import Concatenate, ParamSpec, override |
23 | 23 |
|
24 | 24 | import lightning.pytorch as pl |
25 | 25 |
|
@@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None: |
104 | 104 | _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method |
105 | 105 |
|
106 | 106 |
|
107 | | -class _restricted_classmethod_impl(Generic[_T, _R_co, _P]): |
| 107 | +class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]): |
108 | 108 | """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance |
109 | 109 | instead of a class type.""" |
110 | 110 |
|
| 111 | + method: Callable[Concatenate[type[_T], _P], _R_co] |
| 112 | + |
111 | 113 | def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: |
| 114 | + super().__init__(method) |
112 | 115 | self.method = method |
113 | 116 |
|
114 | | - def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: |
| 117 | + @override |
| 118 | + def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override] |
115 | 119 | # The wrapper ensures that the method can be inspected, but not called on an instance |
116 | 120 | @functools.wraps(self.method) |
117 | 121 | def wrapper(*args: Any, **kwargs: Any) -> _R_co: |
118 | 122 | # Workaround for https://github.com/pytorch/pytorch/issues/67146 |
119 | 123 | is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) |
| 124 | + cls_type = cls if cls is not None else type(instance) |
120 | 125 | if instance is not None and not is_scripting: |
121 | 126 | raise TypeError( |
122 | | - f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." |
| 127 | + f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance." |
123 | 128 | " Please call it on the class type and make sure the return value is used." |
124 | 129 | ) |
125 | | - return self.method(cls, *args, **kwargs) |
| 130 | + return self.method(cls_type, *args, **kwargs) |
126 | 131 |
|
| 132 | + wrapper.__func__ = self.method |
127 | 133 | return wrapper |
128 | 134 |
|
129 | 135 |
|
|
0 commit comments