Skip to content

Commit eee5c80

Browse files
committed
basic support
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a21d7f9 commit eee5c80

File tree

7 files changed

+206
-74
lines changed

7 files changed

+206
-74
lines changed

src/compressed_tensors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# flake8: noqa
1818
from .compressors import *
1919
from .config import *
20+
from .logger import LoggerConfig, configure_logger, logger
2021
from .quantization import QuantizationConfig, QuantizationStatus
2122
from .utils import *
2223
from .version import *

src/compressed_tensors/logger.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Logger configuration for Compressed Tensors.
17+
"""
18+
19+
import os
20+
import sys
21+
from dataclasses import dataclass
22+
from typing import Any, Dict, Optional
23+
24+
from loguru import logger
25+
26+
27+
__all__ = ["LoggerConfig", "configure_logger", "logger"]
28+
29+
30+
# used by `support_log_once`
31+
_logged_once = set()
32+
33+
34+
@dataclass
class LoggerConfig:
    """
    Settings consumed by `configure_logger`.

    Every field can be overridden at configuration time by the matching
    ``COMPRESSED_TENSORS_*`` environment variable read in `configure_logger`.
    """

    # when True, `configure_logger` disables the "compressed_tensors" logger
    # and returns without adding any sink
    disabled: bool = False

    # when True, `configure_logger` calls `logger.remove()` before adding sinks
    clear_loggers: bool = True

    # level of the stdout sink; a falsy value (None or "") skips the stdout sink
    console_log_level: Optional[str] = "INFO"

    # path of the file sink; if unset while `log_file_level` is set,
    # `configure_logger` falls back to "compressed_tensors.log"
    log_file: Optional[str] = None

    # level of the file sink; if unset while `log_file` is set,
    # `configure_logger` falls back to "INFO"
    log_file_level: Optional[str] = None
41+
42+
43+
def configure_logger(config: Optional[LoggerConfig] = None) -> None:
    """
    Configure the logger for Compressed Tensors.
    This function sets up the console and file logging
    as per the specified or default parameters.

    Note: Environment variables take precedence over the function parameters.
    The variables consulted are ``COMPRESSED_TENSORS_LOG_DISABLED``,
    ``COMPRESSED_TENSORS_CLEAR_LOGGERS``, ``COMPRESSED_TENSORS_LOG_LEVEL``,
    ``COMPRESSED_TENSORS_LOG_FILE`` and ``COMPRESSED_TENSORS_LOG_FILE_LEVEL``.
    The two boolean variables are treated as true only when their value is
    ``"true"`` (case-insensitive); level values are upper-cased.

    The ``config`` object passed by the caller is never modified: overrides
    are applied to a copy.

    :param config: The configuration for the logger to use. When ``None``,
        the ``LoggerConfig`` defaults are used.
    :type config: LoggerConfig
    """
    # function-scope import so this block does not depend on the import header
    from dataclasses import replace

    # work on a copy so env overrides never leak into the caller's object
    logger_config = replace(config) if config else LoggerConfig()

    # env vars get priority
    if (disabled := os.getenv("COMPRESSED_TENSORS_LOG_DISABLED")) is not None:
        logger_config.disabled = disabled.lower() == "true"
    if (clear_loggers := os.getenv("COMPRESSED_TENSORS_CLEAR_LOGGERS")) is not None:
        logger_config.clear_loggers = clear_loggers.lower() == "true"
    if (console_log_level := os.getenv("COMPRESSED_TENSORS_LOG_LEVEL")) is not None:
        logger_config.console_log_level = console_log_level.upper()
    if (log_file := os.getenv("COMPRESSED_TENSORS_LOG_FILE")) is not None:
        logger_config.log_file = log_file
    if (log_file_level := os.getenv("COMPRESSED_TENSORS_LOG_FILE_LEVEL")) is not None:
        logger_config.log_file_level = log_file_level.upper()

    if logger_config.disabled:
        # nothing else to set up once the package logger is switched off
        logger.disable("compressed_tensors")
        return

    logger.enable("compressed_tensors")

    if logger_config.clear_loggers:
        # called without a handler id
        # NOTE(review): this presumably also drops sinks registered by the
        # host application on the shared loguru logger -- confirm intended
        logger.remove()

    if logger_config.console_log_level:
        # log as a human readable string with the time, function, level, and message
        logger.add(
            sys.stdout,
            level=logger_config.console_log_level.upper(),
            format="{time} | {function} | {level} - {message}",
            filter=support_log_once,
        )

    if logger_config.log_file or logger_config.log_file_level:
        # setting either file option enables the file sink; the other one
        # falls back to its default
        log_file = logger_config.log_file or "compressed_tensors.log"
        log_file_level = logger_config.log_file_level or "INFO"
        # log as json to the file for easier parsing
        logger.add(
            log_file,
            level=log_file_level.upper(),
            serialize=True,
            filter=support_log_once,
        )
96+
97+
98+
def support_log_once(record: Dict[str, Any]) -> bool:
    """
    Sink filter that lets a record opt into being emitted a single time
    via `.bind(log_once=True)`

    ```
    logger.bind(log_once=False).info("This will log multiple times")
    logger.bind(log_once=False).info("This will log multiple times")
    logger.bind(log_once=True).info("This will only log once")
    logger.bind(log_once=True).info("This will only log once")  # skipped
    ```

    Records are deduplicated on the level name concatenated with the message,
    so the same text at a different level is still emitted.

    :param record: log record; reads `record["extra"]`, `record["level"]`
        and `record["message"]`
    :return: False if the record asked for log-once and an identical
        level/message pair was already accepted, True otherwise
    """
    once_requested = record["extra"].get("log_once", False)
    level_name = getattr(record["level"], "name", "none")
    dedup_key = str(level_name) + record["message"]

    # records that did not opt in always pass and are never remembered
    if not once_requested:
        return True

    # NOTE(review): `_logged_once` is module-level and this filter is attached
    # to both the console and the file sink in `configure_logger`; if the
    # filter runs once per sink, a log-once record accepted by one sink would
    # be rejected by the other -- confirm this is intended
    if dedup_key in _logged_once:
        return False

    _logged_once.add(dedup_key)
    return True
120+
121+
122+
# Set up logging as a side effect of importing this module: console logging
# at INFO is enabled and file logging is left off. The values are spelled out
# explicitly (rather than relying on the LoggerConfig defaults) so the
# import-time behavior is visible here.
configure_logger(
    LoggerConfig(
        disabled=False,
        clear_loggers=True,
        console_log_level="INFO",
        log_file=None,
        log_file_level=None,
    )
)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
calculate_range,
3030
compute_dynamic_scales_and_zp,
3131
)
32-
from compressed_tensors.utils import safe_permute
3332
from torch.nn import Module
3433

3534

@@ -265,8 +264,7 @@ def _process_quantization(
265264
):
266265

267266
output_dtype = dtype if dtype is not None else x.dtype
268-
output = torch.zeros_like(x).to(output_dtype)
269-
columns = output.shape[-1]
267+
columns = x.size(-1)
270268

271269
# TODO: make validation step for inputs
272270

@@ -294,7 +292,7 @@ def _process_quantization(
294292
group_sizes = group_sizes[torch.argsort(group_indices)]
295293

296294
perm = torch.argsort(g_idx)
297-
x = safe_permute(x, perm, dim=1)
295+
x = x.index_select(dim=-1, index=perm)
298296

299297
# Maintain all dimensions except the last dim, which is divided by group_size
300298
reshaped_dims = (
@@ -324,11 +322,11 @@ def _process_quantization(
324322
global_scale=global_scale,
325323
)
326324

327-
output = output.flatten(start_dim=-2)
325+
output = output.flatten(-2, -1)
328326
output = output.to(output_dtype)
329327

330328
if not is_column_order:
331-
output = safe_permute(output, torch.argsort(perm), dim=1)
329+
output = output.index_select(dim=-1, index=torch.argsort(perm))
332330

333331
else: # covers channel, token and tensor strategies
334332
if do_quantize:

src/compressed_tensors/utils/helpers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
import contextlib
1616
import warnings
1717
from functools import wraps
18-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
1919

2020
import numpy
2121
import torch
2222
from frozendict import frozendict
23+
from loguru import logger
2324
from transformers import AutoConfig
2425

2526

27+
T = TypeVar("T", bound="Callable") # used by `deprecated`
28+
29+
2630
if TYPE_CHECKING:
2731
from compressed_tensors.compressors import ModelCompressor
2832

@@ -170,15 +174,17 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
170174
return res
171175

172176

173-
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
177+
def deprecated(
178+
future_name: Optional[str] = None, message: Optional[str] = None
179+
) -> Callable[[T], T]:
174180
"""
175181
Decorator to mark functions as deprecated
176182
177183
:param future_name: Name of the function to be called in place of the deprecated function
178184
:param message: Deprecation message, replaces default deprecation message
179185
"""
180186

181-
def decorator(func: Callable[[Any], Any]):
187+
def decorator(func: T) -> T:
182188
nonlocal message
183189

184190
if message is None:
@@ -190,7 +196,7 @@ def decorator(func: Callable[[Any], Any]):
190196

191197
@wraps(func)
192198
def wrapped(*args, **kwargs):
193-
warnings.warn(message, DeprecationWarning, stacklevel=2)
199+
logger.bind(log_once=True).warning(message)
194200
return func(*args, **kwargs)
195201

196202
return wrapped

src/compressed_tensors/utils/permute.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Set, Tuple
16-
1715
import torch
16+
from compressed_tensors.utils.helpers import deprecated
1817

1918

2019
__all__ = ["safe_permute"]
2120

2221

23-
# these datatypes are missing implementations required for standard permutation
24-
_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
25-
26-
22+
@deprecated("Tensor.index_select")
2723
def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
2824
"""
2925
Perform out-of-place permutation without using torch.Tensor.index_put_,
@@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch
3430
:param dim: dimension along which to apply permutation
3531
:return: permuted value
3632
"""
37-
dtype_tuple = (value.dtype, value.device)
38-
39-
if dtype_tuple in _EXPERIMENTAL_DTYPES:
40-
return _fallback_permute(value, perm, dim)
41-
42-
try:
43-
return value[tuple([slice(None)] * dim + [perm])]
44-
except RuntimeError:
45-
# Mark dtype as experimental if advanced indexing fails
46-
_EXPERIMENTAL_DTYPES.add(dtype_tuple)
47-
return _fallback_permute(value, perm, dim)
48-
49-
50-
def _fallback_permute(
51-
value: torch.Tensor, perm: torch.Tensor, dim: int
52-
) -> torch.Tensor:
53-
"""
54-
Fallback permutation method for experimental dtypes.
55-
56-
:param value: tensor to permute
57-
:param perm: permutation map
58-
:param dim: dimension along which to apply permutation
59-
:return: permuted value
60-
"""
61-
value_ret = value.clone() # cannot use zeros_like b/c of missing impl.
62-
orig_slices = [slice(None)] * (dim + 1)
63-
perm_slices = [slice(None)] * (dim + 1)
64-
65-
for index, perm_index in enumerate(perm):
66-
orig_slices[dim] = index
67-
perm_slices[dim] = perm_index
68-
value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
69-
70-
return value_ret
33+
return value.index_select(dim, perm)

0 commit comments

Comments
 (0)