Skip to content

Commit aa901c6

Browse files
committed
better type hints, warn once
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 52d34ee commit aa901c6

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/compressed_tensors/utils/helpers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
import contextlib
1616
import warnings
1717
from functools import wraps
18-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
1919

2020
import numpy
2121
import torch
2222
from frozendict import frozendict
23+
from loguru import logger
2324
from transformers import AutoConfig
2425

2526

27+
T = TypeVar("T", bound="Callable") # used by `deprecated`
28+
29+
2630
if TYPE_CHECKING:
2731
from compressed_tensors.compressors import ModelCompressor
2832

@@ -170,15 +174,17 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
170174
return res
171175

172176

173-
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
177+
def deprecated(
178+
future_name: Optional[str] = None, message: Optional[str] = None
179+
) -> Callable[[T], T]:
174180
"""
175181
Decorator to mark functions as deprecated
176182
177183
:param new_function: Function called in place of deprecated function
178184
:param message: Deprecation message, replaces default deprecation message
179185
"""
180186

181-
def decorator(func: Callable[[Any], Any]):
187+
def decorator(func: T) -> T:
182188
nonlocal message
183189

184190
if message is None:
@@ -190,7 +196,7 @@ def decorator(func: Callable[[Any], Any]):
190196

191197
@wraps(func)
192198
def wrapped(*args, **kwargs):
193-
warnings.warn(message, DeprecationWarning, stacklevel=2)
199+
logger.bind(log_once=True).warning(message)
194200
return func(*args, **kwargs)
195201

196202
return wrapped

0 commit comments

Comments
 (0)