diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index ed196f24..51f32e94 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,22 @@
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
+
+        compression_formats = None
+        if quantization_format is not None:
+            # llmcompressor incorrectly passes in a CompressionFormat where the
+            # value string is expected - handle both cases
+            if isinstance(quantization_format, (str, CompressionFormat)):
+                quantization_format = [quantization_format]
+
+            compression_formats = quantization_format
+            # assume multiple compression formats mean mixed-precision, as we
+            # currently only support one compressor per precision type and scheme
+            if len(quantization_format) > 1:
+                quantization_format = CompressionFormat.mixed_precision.value
+            else:
+                quantization_format = quantization_format[0]
+
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -203,6 +218,7 @@
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=compression_formats,
         )
 
     @staticmethod
@@ -263,19 +279,36 @@
 
         return quantization_config
 
+    def _fetch_unique_quantization_formats(self) -> List[str]:
+        """
+        Get all unique compression formats present in a model.
+        :return: list of quantization formats
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format is not None and scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+
+        # If the list is empty, fall back to using the global format
+        if len(quantization_formats) == 0:
+            quantization_formats.append(self.quantization_config.format)
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
         # no transform compressor is required
 
@@ -283,10 +316,18 @@
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
+
         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
+            self.quantization_compressor = {}
+            for format in self.compression_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )
 
     # ----- used by hf quantizer ----- #
 
@@ -381,12 +422,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )
 
         return list(unexpected_keys)
 
@@ -424,7 +466,25 @@ def compress_model(self, model: Module):
 
                 # quantization first
                 if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Compressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently "
+                                "not supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
@@ -495,12 +555,28 @@ def decompress_model(self, model: Module):
 
                 # quantization second
                 if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
+
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Decompressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently not "
+                                "supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
                     )
 
                 # remove any existing parameters
@@ -539,7 +615,9 @@ def compress(
 
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format at the moment
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
@@ -588,14 +666,20 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )
 
         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # Note - decompress only supports one compressor at the moment
             params_to_ignore = None
-            if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+            if quant_compressor is not None:
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +690,7 @@
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
 
-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -629,7 +713,7 @@
                 # including initialization
                 load_weight_quantization=(
                     sparse_decompressed
-                    or isinstance(self.quantization_compressor, DenseCompressor)
+                    or isinstance(quant_compressor, DenseCompressor)
                 ),
             )
 
@@ -637,7 +721,7 @@
                 model.state_dict() if sparse_decompressed else model_path
             )
 
-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             # TODO: all weight quantization params will be moved to the compressor
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 09f5f338..5024b1d6 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 29864d25..cdb5b0f3 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -16,6 +16,7 @@
 from copy import deepcopy
 from typing import List, Optional
 
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
@@ -42,12 +43,14 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat value for the layer
     """
 
     targets: List[str]
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None
+    format: Optional[str] = None
 
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 10f9c974..dc48870b 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -20,8 +20,12 @@
 import torch
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+)
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -395,7 +399,7 @@ def _get_combined_config(s_config, q_config):
 )
 def test_compress_model(model_stub, q_format, s_config, tmpdir):
     model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])
 
     # compress model by eagerly compressing state dict
     true_compressed = dict(compressor.compress(model))
@@ -443,7 +447,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         model_stub, torch_dtype=torch.float32
     )
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, q_format
+        cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +463,7 @@
         module.to_empty(device="meta")
 
     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
     compressor.compress_model(meta_model)
 
     # Compare keys and dtypes
@@ -469,6 +473,43 @@
         assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
 
 
+def test_multiple_quant_compressors():
+    model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
+    input_activations = QuantizationArgs(num_bits=8, type="float")
+    weights = QuantizationArgs(num_bits=8, type="float")
+
+    scheme_fp8 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.float_quantized.value,
+    )
+
+    input_activations = QuantizationArgs(num_bits=4, type="float")
+    weights = QuantizationArgs(num_bits=4, type="float")
+
+    scheme_nvfp4 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.nvfp4_pack_quantized.value,
+    )
+
+    model[0].quantization_scheme = scheme_fp8
+    model[0].quantization_status = "frozen"
+    model[1].quantization_scheme = scheme_nvfp4
+    model[1].quantization_status = "frozen"
+
+    formats = [scheme_fp8.format, scheme_nvfp4.format]
+
+    compressor = ModelCompressor.from_pretrained_model(model, None, formats)
+    assert isinstance(compressor.quantization_compressor, dict)
+    assert (
+        compressor.quantization_config.format == CompressionFormat.mixed_precision.value
+    )
+    assert all(format in compressor.quantization_compressor for format in formats)
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [
diff --git a/tests/test_quantization/test_quant_scheme.py b/tests/test_quantization/test_quant_scheme.py
index 0ea7f31f..d1c0d141 100644
--- a/tests/test_quantization/test_quant_scheme.py
+++ b/tests/test_quantization/test_quant_scheme.py
@@ -26,12 +26,13 @@ def test_basic_scheme():
     assert scheme.weights == weights
     assert scheme.input_activations is None
     assert scheme.output_activations is None
+    assert scheme.format is None
 
 
 def test_full_scheme():
     targets = ["Linear"]
     weights = QuantizationArgs()
-    input_activations = QuantizationArgs(num_bits=4)
+    input_activations = QuantizationArgs(num_bits=8)
     output_activations = QuantizationArgs(num_bits=8, type="float", symmetric=False)
 
     scheme = QuantizationScheme(
@@ -39,11 +40,13 @@ def test_full_scheme():
         weights=weights,
         input_activations=input_activations,
         output_activations=output_activations,
+        format="float-quantized",
     )
     assert scheme.targets == targets
     assert scheme.weights == weights
     assert scheme.input_activations == input_activations
     assert scheme.output_activations == output_activations
+    assert scheme.format == "float-quantized"
 
 
 def test_needs_targets():
@@ -57,3 +60,4 @@ def test_defaults():
     assert output.weights is None
     assert output.input_activations is None
     assert output.output_activations is None
+    assert output.format is None
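Usage sketch (illustrative only, not part of the patch): the snippet below mirrors the new test_multiple_quant_compressors test above and shows how per-scheme formats drive the mixed-precision path in ModelCompressor.from_pretrained_model. The tiny Sequential model, layer sizes, and bit-widths are assumptions made purely for demonstration.

import torch
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# two Linear modules, each tagged with a different per-scheme compression format
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
model[0].quantization_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="float"),
    input_activations=QuantizationArgs(num_bits=8, type="float"),
    format=CompressionFormat.float_quantized.value,
)
model[0].quantization_status = "frozen"
model[1].quantization_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, type="float"),
    input_activations=QuantizationArgs(num_bits=4, type="float"),
    format=CompressionFormat.nvfp4_pack_quantized.value,
)
model[1].quantization_status = "frozen"

# passing a list of formats marks the global config as mixed-precision and
# builds one compressor per format (quantization_compressor is now a dict)
formats = [
    CompressionFormat.float_quantized.value,
    CompressionFormat.nvfp4_pack_quantized.value,
]
compressor = ModelCompressor.from_pretrained_model(model, None, formats)
assert compressor.quantization_config.format == CompressionFormat.mixed_precision.value
assert all(fmt in compressor.quantization_compressor for fmt in formats)

# during compress_model / decompress_model, each quantized module is then routed
# to the compressor matching its quantization_scheme.format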