Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/lightning/pytorch/utilities/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from lightning_utilities.core.imports import RequirementCache
from torch import nn
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import Concatenate, ParamSpec, override

import lightning.pytorch as pl

Expand Down Expand Up @@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None:
_R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method


class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):
class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]):
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
instead of a class type."""

method: Callable[Concatenate[type[_T], _P], _R_co]

def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
super().__init__(method)
self.method = method

def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]:
@override
def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override]
# The wrapper ensures that the method can be inspected, but not called on an instance
@functools.wraps(self.method)
def wrapper(*args: Any, **kwargs: Any) -> _R_co:
# Workaround for https://github.com/pytorch/pytorch/issues/67146
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
cls_type = cls if cls is not None else type(instance)
if instance is not None and not is_scripting:
raise TypeError(
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance."
" Please call it on the class type and make sure the return value is used."
)
return self.method(cls, *args, **kwargs)
return self.method(cls_type, *args, **kwargs)

wrapper.__func__ = self.method
return wrapper


Expand Down
Loading