
Commit e4d352b: support more than one quant compressor

1 parent b2df366

File tree: 4 files changed, +36 −14 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 30 additions & 10 deletions

```diff
@@ -164,7 +164,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[List[str]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -267,9 +267,18 @@ def __init__(
                 sparsity_config.format, config=sparsity_config
             )
         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            if isinstance(quantization_config.format, list):
+                self.quantization_compressor = {}
+                for format in quantization_config.format:
+                    self.quantization_compressor[
+                        format
+                    ] = BaseCompressor.load_from_registry(
+                        format, config=quantization_config
+                    )
+            else:
+                self.quantization_compressor = BaseCompressor.load_from_registry(
+                    quantization_config.format, config=quantization_config
+                )
 
         # ----- used by hf quantizer ----- #
 
@@ -407,12 +416,23 @@ def compress_model(self, model: Module):
 
             # quantization first
             if prefix in module_to_scheme:
-                state_dict = self.quantization_compressor.compress(
-                    state_dict,
-                    names_to_scheme=module_to_scheme,
-                    show_progress=False,
-                    compression_device=exec_device,
-                )
+                if isinstance(self.quantization_compressor, dict):
+                    quant_compressor = self.quantization_compressor.get(
+                        module.quantization_scheme.format
+                    )
+                    state_dict = quant_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                        compression_device=exec_device,
+                    )
+                else:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                        compression_device=exec_device,
+                    )
 
             # sparsity second
             if prefix in sparse_compression_targets:
```
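
Taken together, this change makes `self.quantization_compressor` either a single compressor (the old behavior) or a dict mapping format strings to compressors, selected per module via `module.quantization_scheme.format`. A minimal, self-contained sketch of that dispatch pattern follows; `StubCompressor`, `build_quant_compressor`, and the format names are illustrative stand-ins, not the library's real compressors, which the commit resolves through `BaseCompressor.load_from_registry`:

```python
# Sketch of the per-format compressor dispatch introduced in this commit.
# StubCompressor and build_quant_compressor are hypothetical stand-ins.
from typing import Dict, List, Union


class StubCompressor:
    def __init__(self, fmt: str):
        self.fmt = fmt

    def compress(self, state_dict: dict) -> dict:
        # A real compressor would quantize/pack tensors here.
        return {name: (self.fmt, value) for name, value in state_dict.items()}


def build_quant_compressor(
    fmt: Union[str, List[str]]
) -> Union[StubCompressor, Dict[str, StubCompressor]]:
    # Mirrors ModelCompressor.__init__: one format -> one compressor,
    # a list of formats -> a dict keyed by format string.
    if isinstance(fmt, list):
        return {f: StubCompressor(f) for f in fmt}
    return StubCompressor(fmt)


compressor = build_quant_compressor(["pack-quantized", "float-quantized"])
module_format = "float-quantized"  # real code reads module.quantization_scheme.format
quant_compressor = (
    compressor[module_format] if isinstance(compressor, dict) else compressor
)
print(quant_compressor.compress({"layer.weight": "W"}))
```

Note that the dict path in `compress_model` uses `.get()`, so a module whose scheme names a format that was never registered would yield `None` and fail only at the subsequent `.compress()` call.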

src/compressed_tensors/quantization/quant_config.py
Lines changed: 3 additions & 3 deletions

```diff
@@ -138,7 +138,7 @@ class QuantizationConfig(BaseModel):
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
     kv_cache_scheme: Optional[QuantizationArgs] = None
-    format: str = DEFAULT_QUANTIZATION_FORMAT
+    format: Union[List[str], str] = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
@@ -162,7 +162,7 @@ def to_dict(self):
 
     @staticmethod
     def from_pretrained(
-        model: Module, format: Optional[str] = None
+        model: Module, format: Optional[Union[List[str], str]] = None
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
@@ -228,7 +228,7 @@ def from_pretrained(
 
         if format is None:
             if quantization_status == QuantizationStatus.COMPRESSED:
-                format = CompressionFormat.int_quantized.value
+                format = CompressionFormat.int_quantized.value  # why?!
             else:
                 format = CompressionFormat.dense.value
```
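
With `format` widened to `Union[List[str], str]`, a config can now declare several compression formats at once. A hedged usage sketch; the import path and the field values below are assumptions based on the library's public API, not taken from this commit:

```python
# Assumed import path and illustrative field values.
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
)

# `format` now validates as either a single string or a list of strings.
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4),
        ),
    },
    format=["pack-quantized", "float-quantized"],
)
```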
src/compressed_tensors/quantization/quant_scheme.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
@@ -48,6 +49,7 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    format: Optional[CompressionFormat] = None
 
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
```
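
The new optional `format` field lets each scheme carry its own compression format, which is what `compress_model` reads when picking a compressor out of the dict. A short sketch; the field values are illustrative, and `CompressionFormat.pack_quantized` is assumed to be a member of the existing enum:

```python
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Each scheme can now name the format its weights should be compressed with;
# leaving `format` as None preserves the old config-wide behavior.
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
    format=CompressionFormat.pack_quantized,
)
```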

src/compressed_tensors/transform/factory/base.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
```
