| 
19 | 19 | 
 
  | 
20 | 20 | from lightning_utilities.core.imports import RequirementCache  | 
21 | 21 | from torch import nn  | 
22 |  | -from typing_extensions import Concatenate, ParamSpec  | 
 | 22 | +from typing_extensions import Concatenate, ParamSpec, override  | 
23 | 23 | 
 
  | 
24 | 24 | import lightning.pytorch as pl  | 
25 | 25 | 
 
  | 
@@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None:  | 
104 | 104 | _R_co = TypeVar("_R_co", covariant=True)  # return type of the decorated method  | 
105 | 105 | 
 
  | 
106 | 106 | 
 
  | 
107 |  | -class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):  | 
 | 107 | +class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]):  | 
108 | 108 |     """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance  | 
109 | 109 |     instead of a class type."""  | 
110 | 110 | 
 
  | 
 | 111 | +    method: Callable[Concatenate[type[_T], _P], _R_co]  | 
 | 112 | + | 
111 | 113 |     def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:  | 
 | 114 | +        super().__init__(method)  | 
112 | 115 |         self.method = method  | 
113 | 116 | 
 
  | 
114 |  | -    def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]:  | 
 | 117 | +    @override  | 
 | 118 | +    def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]:  # type: ignore[override]  | 
115 | 119 |         # The wrapper ensures that the method can be inspected, but not called on an instance  | 
116 | 120 |         @functools.wraps(self.method)  | 
117 | 121 |         def wrapper(*args: Any, **kwargs: Any) -> _R_co:  | 
118 | 122 |             # Workaround for https://github.com/pytorch/pytorch/issues/67146  | 
119 | 123 |             is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())  | 
 | 124 | +            cls_type = cls if cls is not None else type(instance)  | 
120 | 125 |             if instance is not None and not is_scripting:  | 
121 | 126 |                 raise TypeError(  | 
122 |  | -                    f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."  | 
 | 127 | +                    f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance."  | 
123 | 128 |                     " Please call it on the class type and make sure the return value is used."  | 
124 | 129 |                 )  | 
125 |  | -            return self.method(cls, *args, **kwargs)  | 
 | 130 | +            return self.method(cls_type, *args, **kwargs)  | 
126 | 131 | 
 
  | 
 | 132 | +        wrapper.__func__ = self.method  | 
127 | 133 |         return wrapper  | 
128 | 134 | 
 
  | 
129 | 135 | 
 
  | 
 | 
0 commit comments