
Commit 30ae05c

clean-up; add mixed-precision format

1 parent 5d6ebe8

File tree

3 files changed: +55 −42 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 53 additions & 41 deletions
```diff
@@ -182,7 +182,13 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
+        # assume multiple compression formats means mixed-precision,
+        # as we currently only support one compressor per precision type and scheme
+        if len(quantization_format) > 1:
+            quantization_format = CompressionFormat.mixed_precision
+        else:
+            quantization_format = quantization_format[0]
+
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
```
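The new selection logic collapses a list of per-module formats into one value. A minimal standalone sketch of that behavior (the `resolve_format` helper and the trimmed-down enum are illustrative, not library code):

```python
from enum import Enum


class CompressionFormat(Enum):
    # subset of the real enum, for illustration only
    pack_quantized = "pack-quantized"
    nvfp4_pack_quantized = "nvfp4-pack-quantized"
    mixed_precision = "mixed-precision"


def resolve_format(formats: list):
    # hypothetical helper mirroring the new branch in from_pretrained_model:
    # more than one per-module format implies mixed-precision
    if len(formats) > 1:
        return CompressionFormat.mixed_precision
    return formats[0]


assert resolve_format(["pack-quantized"]) == "pack-quantized"
assert (
    resolve_format(["pack-quantized", "nvfp4-pack-quantized"])
    is CompressionFormat.mixed_precision
)
```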
```diff
@@ -263,6 +269,17 @@ def parse_quantization_config(
 
         return quantization_config
 
+    def _fetch_unique_quantization_formats(self):
+        """
+        Get all unique compression formats used in
+        the model
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
```
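The helper is an order-preserving dedup over the schemes' format strings. A sketch of its behavior (the `SimpleNamespace` stubs stand in for `QuantizationScheme` objects, which expose a `.format` attribute):

```python
from types import SimpleNamespace

# stand-ins for the QuantizationScheme values in config_groups
config_groups = {
    "group_0": SimpleNamespace(format="pack-quantized"),
    "group_1": SimpleNamespace(format="nvfp4-pack-quantized"),
    "group_2": SimpleNamespace(format="pack-quantized"),  # duplicate format
}

quantization_formats = []
for _, scheme in config_groups.items():
    if scheme.format not in quantization_formats:
        quantization_formats.append(scheme.format)

# order-preserving dedup: ['pack-quantized', 'nvfp4-pack-quantized']
print(quantization_formats)
```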
```diff
@@ -275,26 +292,24 @@ def __init__(
 
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
         # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
+
+        quantization_formats = self._fetch_unique_quantization_formats()
+
         if quantization_config is not None:
-            if isinstance(quantization_config.format, list):
-                self.quantization_compressor = {}
-                for format in quantization_config.format:
-                    self.quantization_compressor[
-                        format
-                    ] = BaseCompressor.load_from_registry(
-                        format, config=quantization_config
-                    )
-            else:
-                self.quantization_compressor = BaseCompressor.load_from_registry(
-                    quantization_config.format, config=quantization_config
+            self.quantization_compressor = {}
+            for format in quantization_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
                 )
 
         # ----- used by hf quantizer ----- #
```
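Conceptually, `__init__` now always builds one compressor per unique format, keyed by format string. A sketch with stub compressor classes (a plain dict replaces the real `BaseCompressor.load_from_registry` call here):

```python
# stub "registry" keyed by format string; the real code calls
# BaseCompressor.load_from_registry(format, config=quantization_config)
class PackedCompressor: ...
class NVFP4Compressor: ...

REGISTRY = {
    "pack-quantized": PackedCompressor,
    "nvfp4-pack-quantized": NVFP4Compressor,
}

quantization_formats = ["pack-quantized", "nvfp4-pack-quantized"]

# one compressor instance per unique format
quantization_compressor = {fmt: REGISTRY[fmt]() for fmt in quantization_formats}
```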
```diff
@@ -433,23 +448,15 @@ def compress_model(self, model: Module):
 
             # quantization first
             if prefix in module_to_scheme:
-                if isinstance(self.quantization_compressor, dict):
-                    quant_compressor = self.quantization_compressor.get(
-                        module.quantization_scheme.format
-                    )
-                    state_dict = quant_compressor.compress(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
-                        show_progress=False,
-                        compression_device=exec_device,
-                    )
-                else:
-                    state_dict = self.quantization_compressor.compress(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
-                        show_progress=False,
-                        compression_device=exec_device,
-                    )
+                quant_compressor = self.quantization_compressor.get(
+                    module.quantization_scheme.format
+                )
+                state_dict = quant_compressor.compress(
+                    state_dict,
+                    names_to_scheme=module_to_scheme,
+                    show_progress=False,
+                    compression_device=exec_device,
+                )
 
             # sparsity second
             if prefix in sparse_compression_targets:
```
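During compression each module now routes to the compressor matching its own scheme's format. Roughly (module and scheme objects stubbed; the dict values here are placeholders, not real compressors):

```python
from types import SimpleNamespace

quantization_compressor = {
    "pack-quantized": SimpleNamespace(name="pack"),
    "nvfp4-pack-quantized": SimpleNamespace(name="nvfp4"),
}

# a module whose attached scheme requests the nvfp4 packed format
module = SimpleNamespace(
    quantization_scheme=SimpleNamespace(format="nvfp4-pack-quantized")
)

quant_compressor = quantization_compressor.get(module.quantization_scheme.format)
print(quant_compressor.name)  # nvfp4
```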
```diff
@@ -515,12 +522,13 @@ def decompress_model(self, model: Module):
 
             # quantization second
             if prefix in module_to_scheme:
-                state_dict = (
-                    self.quantization_compressor.decompress_module_from_state_dict(
-                        prefix,
-                        state_dict,
-                        scheme=module_to_scheme[prefix],
-                    )
+                quant_compressor = self.quantization_compressor.get(
+                    module.quantization_scheme.format
+                )
+                state_dict = quant_compressor.decompress_module_from_state_dict(
+                    prefix,
+                    state_dict,
+                    scheme=module_to_scheme[prefix],
                 )
 
             # remove any existing parameters
```
```diff
@@ -559,7 +567,9 @@ def compress(
 
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note: compress only supports one compression format at the moment
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
```
```diff
@@ -613,9 +623,11 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # Note: decompress only supports one compressor so far
+            quant_compressor = next(iter(self.quantization_compressor.values()))
             params_to_ignore = None
             if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
```
```diff
@@ -626,7 +638,7 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
 
-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -649,15 +661,15 @@ def decompress(self, model_path: str, model: Module):
                     # including initialization
                     load_weight_quantization=(
                         sparse_decompressed
-                        or isinstance(self.quantization_compressor, DenseCompressor)
+                        or isinstance(quant_compressor, DenseCompressor)
                     ),
                 )
 
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )
 
-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
             # TODO: all weight quantization params will be moved to the compressor
```
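The whole-model compress/decompress paths still assume a single compressor, now pulled out of the per-format dict. A minimal sketch of that extraction (the dict value is a placeholder string, not a real compressor):

```python
quantization_compressor = {"pack-quantized": "<compressor instance>"}

# iterating a dict yields its keys (format strings), so the lone
# compressor instance has to come from .values()
quant_compressor = next(iter(quantization_compressor.values()))
assert quant_compressor == "<compressor instance>"
```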

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
 
 
```
src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -138,7 +138,7 @@ class QuantizationConfig(BaseModel):
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
     kv_cache_scheme: Optional[QuantizationArgs] = None
-    format: Union[List[str], str] = DEFAULT_QUANTIZATION_FORMAT
+    format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
```
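With the field narrowed back to a single string, a mixed-precision checkpoint would carry the sentinel format at the top level while each config group keeps its own per-scheme format. A hypothetical serialized config (key names follow `QuantizationConfig`; the concrete values are illustrative):

```python
# hypothetical quantization config dict for a mixed-precision model
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "mixed-precision",  # single top-level string again
    "config_groups": {
        "group_0": {"format": "pack-quantized", "targets": ["Linear"]},
        "group_1": {"format": "nvfp4-pack-quantized", "targets": ["Linear"]},
    },
}
```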
