15
15
import contextlib
16
16
import warnings
17
17
from functools import wraps
18
- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Mapping , Optional
18
+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Mapping , Optional , TypeVar
19
19
20
20
import numpy
21
21
import torch
22
22
from frozendict import frozendict
23
+ from loguru import logger
23
24
from transformers import AutoConfig
24
25
25
26
27
+ T = TypeVar ("T" , bound = "Callable" ) # used by `deprecated`
28
+
29
+
26
30
if TYPE_CHECKING :
27
31
from compressed_tensors .compressors import ModelCompressor
28
32
@@ -170,15 +174,17 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
170
174
return res
171
175
172
176
173
- def deprecated (future_name : Optional [str ] = None , message : Optional [str ] = None ):
177
+ def deprecated (
178
+ future_name : Optional [str ] = None , message : Optional [str ] = None
179
+ ) -> Callable [[T ], T ]:
174
180
"""
175
181
Decorator to mark functions as deprecated
176
182
177
183
:param new_function: Function called in place of deprecated function
178
184
:param message: Deprecation message, replaces default deprecation message
179
185
"""
180
186
181
- def decorator (func : Callable [[ Any ], Any ]) :
187
+ def decorator (func : T ) -> T :
182
188
nonlocal message
183
189
184
190
if message is None :
@@ -190,7 +196,7 @@ def decorator(func: Callable[[Any], Any]):
190
196
191
197
@wraps (func )
192
198
def wrapped (* args , ** kwargs ):
193
- warnings . warn ( message , DeprecationWarning , stacklevel = 2 )
199
+ logger . bind ( log_once = True ). warning ( message )
194
200
return func (* args , ** kwargs )
195
201
196
202
return wrapped
0 commit comments