Skip to content

Commit 891da51

Browse files
authored
[Logging] Support use of loguru (#454)
* support loguru Signed-off-by: Kyle Sayers <[email protected]> * add loguru Signed-off-by: Kyle Sayers <[email protected]> * fix quality Signed-off-by: Kyle Sayers <[email protected]> * address Signed-off-by: Kyle Sayers <[email protected]> * [Utils] Improve type hints for `deprecated`, only log once (#455) * better type hints, warn once Signed-off-by: Kyle Sayers <[email protected]> * remove unneeded import Signed-off-by: Kyle Sayers <[email protected]> * use warnings Signed-off-by: Kyle Sayers <[email protected]> --------- Signed-off-by: Kyle Sayers <[email protected]> --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent d2daa9a commit 891da51

File tree

4 files changed

+136
-5
lines changed

4 files changed

+136
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
8888
)
8989

9090
def _setup_install_requires() -> List:
91-
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
91+
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict", "loguru"]
9292

9393
def _setup_extras() -> Dict:
9494
return {

src/compressed_tensors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# flake8: noqa
16+
# isort: off
17+
from .logger import LoggerConfig, configure_logger, logger
1518
from .base import *
1619

17-
# flake8: noqa
1820
from .compressors import *
1921
from .config import *
2022
from .quantization import QuantizationConfig, QuantizationStatus

src/compressed_tensors/logger.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Logger configuration for Compressed Tensors.
17+
"""
18+
19+
import os
20+
import sys
21+
from dataclasses import dataclass
22+
from typing import Any, Dict, Optional
23+
24+
from loguru import logger
25+
26+
27+
__all__ = ["LoggerConfig", "configure_logger", "logger"]
28+
29+
30+
# used by `support_log_once``
31+
_logged_once = set()
32+
33+
34+
@dataclass
35+
class LoggerConfig:
36+
disabled: bool = False
37+
clear_loggers: bool = True
38+
console_log_level: Optional[str] = "INFO"
39+
log_file: Optional[str] = None
40+
log_file_level: Optional[str] = None
41+
42+
43+
def configure_logger(config: Optional[LoggerConfig] = None):
44+
"""
45+
Configure the logger for Compressed Tensors.
46+
This function sets up the console and file logging
47+
as per the specified or default parameters.
48+
49+
Note: Environment variables take precedence over the function parameters.
50+
51+
:param config: The configuration for the logger to use.
52+
:type config: LoggerConfig
53+
"""
54+
logger_config = config or LoggerConfig()
55+
56+
# env vars get priority
57+
if bool(os.getenv("COMPRESSED_TENSORS_LOG_DISABLED")):
58+
logger_config.disabled = True
59+
if bool(os.getenv("COMPRESSED_TENSORS_CLEAR_LOGGERS")):
60+
logger_config.clear_loggers = True
61+
if (console_log_level := os.getenv("COMPRESSED_TENSORS_LOG_LEVEL")) is not None:
62+
logger_config.console_log_level = console_log_level.upper()
63+
if (log_file := os.getenv("COMPRESSED_TENSORS_LOG_FILE")) is not None:
64+
logger_config.log_file = log_file
65+
if (log_file_level := os.getenv("COMPRESSED_TENSORS_LOG_FILE_LEVEL")) is not None:
66+
logger_config.log_file_level = log_file_level.upper()
67+
68+
if logger_config.disabled:
69+
logger.disable("compressed_tensors")
70+
return
71+
72+
logger.enable("compressed_tensors")
73+
74+
if logger_config.clear_loggers:
75+
logger.remove()
76+
77+
if logger_config.console_log_level:
78+
# log as a human readable string with the time, function, level, and message
79+
logger.add(
80+
sys.stdout,
81+
level=logger_config.console_log_level.upper(),
82+
format="{time} | {function} | {level} - {message}",
83+
filter=support_log_once,
84+
)
85+
86+
if logger_config.log_file or logger_config.log_file_level:
87+
log_file = logger_config.log_file or "compressed_tensors.log"
88+
log_file_level = logger_config.log_file_level or "INFO"
89+
# log as json to the file for easier parsing
90+
logger.add(
91+
log_file,
92+
level=log_file_level.upper(),
93+
serialize=True,
94+
filter=support_log_once,
95+
)
96+
97+
98+
def support_log_once(record: Dict[str, Any]) -> bool:
99+
"""
100+
Support logging only once using `.bind(log_once=True)`
101+
102+
```
103+
logger.bind(log_once=False).info("This will log multiple times")
104+
logger.bind(log_once=False).info("This will log multiple times")
105+
logger.bind(log_once=True).info("This will only log once")
106+
logger.bind(log_once=True).info("This will only log once") # skipped
107+
```
108+
"""
109+
log_once = record["extra"].get("log_once", False)
110+
level = getattr(record["level"], "name", "none")
111+
message = hash(str(level) + record["message"])
112+
113+
if log_once and message in _logged_once:
114+
return False
115+
116+
if log_once:
117+
_logged_once.add(message)
118+
119+
return True
120+
121+
122+
# invoke logger setup on import with default values enabling console logging with INFO
123+
# and disabling file logging
124+
configure_logger(config=LoggerConfig())

src/compressed_tensors/utils/helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515
import contextlib
1616
import warnings
1717
from functools import wraps
18-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
1919

2020
import numpy
2121
import torch
2222
from frozendict import frozendict
2323
from transformers import AutoConfig
2424

2525

26+
T = TypeVar("T", bound="Callable") # used by `deprecated`
27+
28+
2629
if TYPE_CHECKING:
2730
from compressed_tensors.compressors import ModelCompressor
2831

@@ -170,15 +173,17 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
170173
return res
171174

172175

173-
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
176+
def deprecated(
177+
future_name: Optional[str] = None, message: Optional[str] = None
178+
) -> Callable[[T], T]:
174179
"""
175180
Decorator to mark functions as deprecated
176181
177182
:param new_function: Function called in place of deprecated function
178183
:param message: Deprecation message, replaces default deprecation message
179184
"""
180185

181-
def decorator(func: Callable[[Any], Any]):
186+
def decorator(func: T) -> T:
182187
nonlocal message
183188

184189
if message is None:

0 commit comments

Comments
 (0)