|
15 | 15 | import inspect
|
16 | 16 | import logging
|
17 | 17 | import os
|
18 |
| -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar |
| 18 | +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar |
19 | 19 |
|
20 | 20 | from lightning_utilities.core.imports import RequirementCache
|
21 | 21 | from torch import nn
|
@@ -104,16 +104,18 @@ def _check_mixed_imports(instance: object) -> None:
|
104 | 104 | _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method
|
105 | 105 |
|
106 | 106 |
|
107 |
| -class _restricted_classmethod_impl(classmethod): |
| 107 | +class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]): |
108 | 108 | """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
|
109 | 109 | instead of a class type."""
|
110 | 110 |
|
| 111 | + method: Callable[Concatenate[type[_T], _P], _R_co] |
| 112 | + |
111 | 113 | def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
|
112 | 114 | super().__init__(method)
|
113 | 115 | self.method = method
|
114 | 116 |
|
115 | 117 | @override
|
116 |
| - def __get__(self, instance: Optional[_T], cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: |
| 118 | + def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override] |
117 | 119 | # The wrapper ensures that the method can be inspected, but not called on an instance
|
118 | 120 | @functools.wraps(self.method)
|
119 | 121 | def wrapper(*args: Any, **kwargs: Any) -> _R_co:
|
|
0 commit comments