Skip to content

Commit 405df4a

Browse files
committed
refactored _restricted_classmethod_impl
1 parent 8eff1f4 commit 405df4a

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

src/lightning/pytorch/utilities/model_helpers.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -104,29 +104,23 @@ def _check_mixed_imports(instance: object) -> None:
104104
_R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method
105105

106106

107-
class _restricted_classmethod_impl(classmethod):
107+
def _restricted_classmethod_impl(method: Callable[Concatenate[type[_T], _P], _R_co]) -> classmethod:
108108
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
109109
instead of a class type."""
110110

111-
def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
112-
super().__init__(method)
113-
self.method = method
114-
115-
def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]:
116-
# The wrapper ensures that the method can be inspected, but not called on an instance
117-
@functools.wraps(self.method)
118-
def wrapper(*args: Any, **kwargs: Any) -> _R_co:
119-
# Workaround for https://github.com/pytorch/pytorch/issues/67146
120-
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
121-
cls_type = cls if cls is not None else type(instance)
122-
if instance is not None and not is_scripting:
123-
raise TypeError(
124-
f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance."
125-
" Please call it on the class type and make sure the return value is used."
126-
)
127-
return self.method(cls_type, *args, **kwargs)
128-
129-
return wrapper
111+
# The wrapper ensures that the method can be inspected, but not called on an instance
112+
@functools.wraps(method)
113+
def wrapper(cls: type[_T], *args: _P.args, **kwargs: _P.kwargs) -> _R_co:
114+
# Workaround for https://github.com/pytorch/pytorch/issues/67146
115+
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
116+
if inspect.isclass(cls) and not is_scripting:
117+
raise TypeError(
118+
f"The classmethod `{cls.__name__}.{method.__name__}` cannot be called on an instance."
119+
" Please call it on the class type and make sure the return value is used."
120+
)
121+
return method(cls, *args, **kwargs)
122+
123+
return classmethod(wrapper)
130124

131125

132126
if TYPE_CHECKING:

0 commit comments

Comments
 (0)