diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index daf1c400c03df..15a0d96e383c0 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -19,7 +19,7 @@ from lightning_utilities.core.imports import RequirementCache from torch import nn -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, ParamSpec, override import lightning.pytorch as pl @@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None: _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method -class _restricted_classmethod_impl(Generic[_T, _R_co, _P]): +class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" + method: Callable[Concatenate[type[_T], _P], _R_co] + def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: + super().__init__(method) self.method = method - def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: + @override + def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override] # The wrapper ensures that the method can be inspected, but not called on an instance @functools.wraps(self.method) def wrapper(*args: Any, **kwargs: Any) -> _R_co: # Workaround for https://github.com/pytorch/pytorch/issues/67146 is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) + cls_type = cls if cls is not None else type(instance) if instance is not None and not is_scripting: raise TypeError( - f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." + f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance." " Please call it on the class type and make sure the return value is used." ) - return self.method(cls, *args, **kwargs) + return self.method(cls_type, *args, **kwargs) + wrapper.__func__ = self.method return wrapper