Skip to content

Commit 926c2bc

Browse files
committed
clean-up; add mixed-precision format
1 parent e4d352b commit 926c2bc

File tree

3 files changed

+57
-42
lines changed

3 files changed

+57
-42
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ def from_pretrained_model(
177177
algorithm
178178
:return: compressor for the configs, or None if model is not compressed
179179
"""
180+
# assume multiple compression formats means mixed-precision
181+
# as we currently only support one compressor per precision type and scheme
182+
if len(quantization_format) > 1:
183+
quantization_format = CompressionFormat.mixed_precision
184+
else:
185+
quantization_format = quantization_format[0]
186+
180187
quantization_config = QuantizationConfig.from_pretrained(
181188
model, format=quantization_format
182189
)
@@ -190,7 +197,8 @@ def from_pretrained_model(
190197
return None
191198

192199
return cls(
193-
sparsity_config=sparsity_config, quantization_config=quantization_config
200+
sparsity_config=sparsity_config,
201+
quantization_config=quantization_config,
194202
)
195203

196204
@staticmethod
@@ -250,6 +258,17 @@ def parse_quantization_config(
250258

251259
return quantization_config
252260

261+
def _fetch_unique_quantization_formats(self):
262+
"""
263+
Get all unique compression formats used in
264+
the model
265+
"""
266+
quantization_formats = []
267+
for _, scheme in self.quantization_config.config_groups.items():
268+
if scheme.format not in quantization_formats:
269+
quantization_formats.append(scheme)
270+
return quantization_formats
271+
253272
def __init__(
254273
self,
255274
sparsity_config: Optional[SparsityCompressionConfig] = None,
@@ -259,25 +278,23 @@ def __init__(
259278
self.quantization_config = quantization_config
260279
self.sparsity_compressor = None
261280
self.quantization_compressor: Optional[
262-
Union[BaseQuantizationCompressor, DenseCompressor]
281+
Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
263282
] = None
264283

265284
if sparsity_config is not None:
266285
self.sparsity_compressor = BaseCompressor.load_from_registry(
267286
sparsity_config.format, config=sparsity_config
268287
)
288+
289+
quantization_formats = self._fetch_unique_quantization_formats()
290+
269291
if quantization_config is not None:
270-
if isinstance(quantization_config.format, list):
271-
self.quantization_compressor = {}
272-
for format in quantization_config.format:
273-
self.quantization_compressor[
274-
format
275-
] = BaseCompressor.load_from_registry(
276-
format, config=quantization_config
277-
)
278-
else:
279-
self.quantization_compressor = BaseCompressor.load_from_registry(
280-
quantization_config.format, config=quantization_config
292+
self.quantization_compressor = {}
293+
for format in quantization_formats:
294+
self.quantization_compressor[
295+
format
296+
] = BaseCompressor.load_from_registry(
297+
format, config=quantization_config
281298
)
282299

283300
# ----- used by hf quantizer ----- #
@@ -416,23 +433,15 @@ def compress_model(self, model: Module):
416433

417434
# quantization first
418435
if prefix in module_to_scheme:
419-
if isinstance(self.quantization_compressor, dict):
420-
quant_compressor = self.quantization_compressor.get(
421-
module.quantization_scheme.format
422-
)
423-
state_dict = quant_compressor.compress(
424-
state_dict,
425-
names_to_scheme=module_to_scheme,
426-
show_progress=False,
427-
compression_device=exec_device,
428-
)
429-
else:
430-
state_dict = self.quantization_compressor.compress(
431-
state_dict,
432-
names_to_scheme=module_to_scheme,
433-
show_progress=False,
434-
compression_device=exec_device,
435-
)
436+
quant_compressor = self.quantization_compressor.get(
437+
module.quantization_scheme.format
438+
)
439+
state_dict = quant_compressor.compress(
440+
state_dict,
441+
names_to_scheme=module_to_scheme,
442+
show_progress=False,
443+
compression_device=exec_device,
444+
)
436445

437446
# sparsity second
438447
if prefix in sparse_compression_targets:
@@ -498,12 +507,13 @@ def decompress_model(self, model: Module):
498507

499508
# quantization second
500509
if prefix in module_to_scheme:
501-
state_dict = (
502-
self.quantization_compressor.decompress_module_from_state_dict(
503-
prefix,
504-
state_dict,
505-
scheme=module_to_scheme[prefix],
506-
)
510+
quant_compressor = self.quantization_compressor.get(
511+
module.quantization_scheme.format
512+
)
513+
state_dict = quant_compressor.decompress_module_from_state_dict(
514+
prefix,
515+
state_dict,
516+
scheme=module_to_scheme[prefix],
507517
)
508518

509519
# remove any existing parameters
@@ -542,7 +552,9 @@ def compress(
542552

543553
if self.quantization_compressor is not None:
544554
module_to_scheme = map_module_to_scheme(model)
545-
state_dict = self.quantization_compressor.compress(
555+
# Note - compress only supports one compression format at the moment
556+
quant_compressor = next(iter(self.quantization_compressor))
557+
state_dict = quant_compressor.compress(
546558
state_dict,
547559
names_to_scheme=module_to_scheme,
548560
show_progress=show_progress,
@@ -596,9 +608,11 @@ def decompress(self, model_path: str, model: Module):
596608
self.sparsity_compressor is not None
597609
and self.sparsity_config.format != CompressionFormat.dense.value
598610
):
611+
# note - decompress only supports one compressor so far
612+
quant_compressor = next(iter(self.quantization_compressor))
599613
params_to_ignore = None
600614
if self.quantization_compressor is not None:
601-
params_to_ignore = self.quantization_compressor.compression_param_names
615+
params_to_ignore = quant_compressor.compression_param_names
602616
# Sparse decompression is applied on the model_path
603617
# The compressor will try and load any quantization parameters as well
604618
# params_to_skip_load will skip over quantization params from being loaded
@@ -609,7 +623,7 @@ def decompress(self, model_path: str, model: Module):
609623
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
610624
sparse_decompressed = True
611625

612-
if self.quantization_compressor is not None:
626+
if quant_compressor is not None:
613627
# Temporarily set quantization status to FROZEN to prevent
614628
# quantization during apply_quantization_config. This ensures
615629
# that the dtypes of the weights are not unintentionally updated.
@@ -632,15 +646,15 @@ def decompress(self, model_path: str, model: Module):
632646
# including initialization
633647
load_weight_quantization=(
634648
sparse_decompressed
635-
or isinstance(self.quantization_compressor, DenseCompressor)
649+
or isinstance(quant_compressor, DenseCompressor)
636650
),
637651
)
638652

639653
model_path_or_state_dict = (
640654
model.state_dict() if sparse_decompressed else model_path
641655
)
642656

643-
dense_gen = self.quantization_compressor.decompress(
657+
dense_gen = quant_compressor.decompress(
644658
model_path_or_state_dict, names_to_scheme=names_to_scheme
645659
)
646660
# TODO: all weight quantization params will be moved to the compressor

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
3232
naive_quantized = "naive-quantized"
3333
pack_quantized = "pack-quantized"
3434
marlin_24 = "marlin-24"
35+
mixed_precision = "mixed-precision"
3536
nvfp4_pack_quantized = "nvfp4-pack-quantized"
3637

3738

src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class QuantizationConfig(BaseModel):
138138
config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
139139
quant_method: str = DEFAULT_QUANTIZATION_METHOD
140140
kv_cache_scheme: Optional[QuantizationArgs] = None
141-
format: Union[List[str], str] = DEFAULT_QUANTIZATION_FORMAT
141+
format: str = DEFAULT_QUANTIZATION_FORMAT
142142
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
143143
global_compression_ratio: Optional[float] = None
144144
ignore: Optional[List[str]] = Field(default_factory=list)

0 commit comments

Comments
 (0)