|
29 | 29 | QUANTIZATION_CONFIG_NAME,
|
30 | 30 | QUANTIZATION_METHOD_NAME,
|
31 | 31 | SPARSITY_CONFIG_NAME,
|
| 32 | + TRANSFORM_CONFIG_NAME, |
32 | 33 | )
|
33 | 34 | from compressed_tensors.compressors.base import BaseCompressor
|
34 | 35 | from compressed_tensors.compressors.sparse_compressors import DenseCompressor
|
|
43 | 44 | )
|
44 | 45 | from compressed_tensors.quantization.lifecycle import expand_target_names
|
45 | 46 | from compressed_tensors.quantization.utils import is_module_quantized
|
| 47 | +from compressed_tensors.transform import TransformConfig |
46 | 48 | from compressed_tensors.utils import (
|
47 | 49 | align_module_device,
|
48 | 50 | delete_offload_parameter,
|
@@ -105,6 +107,7 @@ class ModelCompressor:
|
105 | 107 |
|
106 | 108 | sparsity_config: Optional[SparsityCompressionConfig] = None
|
107 | 109 | quantization_config: Optional[QuantizationConfig] = None
|
| 110 | + transform_config: Optional[TransformConfig] = None |
108 | 111 |
|
109 | 112 | @classmethod
|
110 | 113 | def from_pretrained(
|
@@ -144,6 +147,8 @@ def from_compression_config(
|
144 | 147 |
|
145 | 148 | sparsity_config = cls.parse_sparsity_config(compression_config)
|
146 | 149 | quantization_config = cls.parse_quantization_config(compression_config)
|
| 150 | + # TODO: transform config is not supported by CompressedTensorsConfig yet |
| 151 | + |
147 | 152 | if sparsity_config is None and quantization_config is None:
|
148 | 153 | return None
|
149 | 154 |
|
@@ -177,20 +182,27 @@ def from_pretrained_model(
|
177 | 182 | algorithm
|
178 | 183 | :return: compressor for the configs, or None if model is not compressed
|
179 | 184 | """
|
| 185 | + # reconstruct config from schemes attached to modules |
180 | 186 | quantization_config = QuantizationConfig.from_pretrained(
|
181 | 187 | model, format=quantization_format
|
182 | 188 | )
|
183 | 189 |
|
| 190 | + # use config passed as argument |
184 | 191 | if isinstance(sparsity_config, str): # we passed in a sparsity format
|
185 | 192 | sparsity_config = SparsityCompressionConfig.load_from_registry(
|
186 | 193 | sparsity_config
|
187 | 194 | )
|
188 | 195 |
|
189 |
| - if sparsity_config is None and quantization_config is None: |
| 196 | + # use config attached to model |
| 197 | + transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None) |
| 198 | + |
| 199 | + if not any((quantization_config, sparsity_config, transform_config)): |
190 | 200 | return None
|
191 | 201 |
|
192 | 202 | return cls(
|
193 |
| - sparsity_config=sparsity_config, quantization_config=quantization_config |
| 203 | + sparsity_config=sparsity_config, |
| 204 | + quantization_config=quantization_config, |
| 205 | + transform_config=transform_config, |
194 | 206 | )
|
195 | 207 |
|
196 | 208 | @staticmethod
|
@@ -254,13 +266,17 @@ def __init__(
|
254 | 266 | self,
|
255 | 267 | sparsity_config: Optional[SparsityCompressionConfig] = None,
|
256 | 268 | quantization_config: Optional[QuantizationConfig] = None,
|
| 269 | + transform_config: Optional[TransformConfig] = None, |
257 | 270 | ):
|
258 | 271 | self.sparsity_config = sparsity_config
|
259 | 272 | self.quantization_config = quantization_config
|
| 273 | + self.transform_config = transform_config |
| 274 | + |
260 | 275 | self.sparsity_compressor = None
|
261 | 276 | self.quantization_compressor: Optional[
|
262 | 277 | Union[BaseQuantizationCompressor, DenseCompressor]
|
263 | 278 | ] = None
|
| 279 | + # no transform compressor is required |
264 | 280 |
|
265 | 281 | if sparsity_config is not None:
|
266 | 282 | self.sparsity_compressor = BaseCompressor.load_from_registry(
|
@@ -640,43 +656,49 @@ def update_config(self, save_directory: str):
|
640 | 656 |
|
641 | 657 | :param save_directory: path to a folder containing a HF model config
|
642 | 658 | """
|
643 |
| - if self.quantization_config is None and self.sparsity_config is None: |
| 659 | + # this check is also done in `from_pretrained_model`, |
| 660 | + # but not in `from_pretrained` or `from_compression_config` |
| 661 | + if not any( |
| 662 | + (self.quantization_config, self.sparsity_config, self.transform_config) |
| 663 | + ): |
644 | 664 | return
|
645 | 665 |
|
| 666 | + # write to config.json file, regardless of whether it exists already |
| 667 | + # overwrite previous config and version if already existing |
646 | 668 | config_file_path = os.path.join(save_directory, CONFIG_NAME)
|
647 |
| - if not os.path.exists(config_file_path): |
648 |
| - _LOGGER.warning( |
649 |
| - f"Could not find a valid model config file in " |
650 |
| - f"{save_directory}. Compression config will not be saved." |
651 |
| - ) |
652 |
| - return |
| 669 | + if os.path.exists(config_file_path): |
| 670 | + with open(config_file_path, "r") as file: |
| 671 | + config_data = json.load(file) |
| 672 | + else: |
| 673 | + config_data = {} |
653 | 674 |
|
654 |
| - with open(config_file_path, "r") as config_file: |
655 |
| - config_data = json.load(config_file) |
| 675 | + # serialize configs into json |
| 676 | + qconfig_data = ( |
| 677 | + self.quantization_config.model_dump(exclude=["quant_method", "format"]) |
| 678 | + if self.quantization_config is not None |
| 679 | + else {} |
| 680 | + ) |
| 681 | + sconfig_data = ( |
| 682 | + self.sparsity_config.model_dump() |
| 683 | + if self.sparsity_config is not None |
| 684 | + else {} |
| 685 | + ) |
| 686 | + tconfig_data = ( |
| 687 | + self.transform_config.model_dump() |
| 688 | + if self.transform_config is not None |
| 689 | + else {} |
| 690 | + ) |
656 | 691 |
|
657 |
| - # required metadata whenever a quantization or sparsity config is present |
658 |
| - # overwrite previous config and version if already existing |
659 |
| - config_data[QUANTIZATION_CONFIG_NAME] = {} |
660 |
| - config_data[QUANTIZATION_CONFIG_NAME][ |
661 |
| - COMPRESSION_VERSION_NAME |
662 |
| - ] = compressed_tensors.__version__ |
663 |
| - if self.quantization_config is not None: |
664 |
| - self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD |
665 |
| - else: |
666 |
| - config_data[QUANTIZATION_CONFIG_NAME][ |
667 |
| - QUANTIZATION_METHOD_NAME |
668 |
| - ] = DEFAULT_QUANTIZATION_METHOD |
669 |
| - |
670 |
| - # quantization and sparsity configs |
671 |
| - if self.quantization_config is not None: |
672 |
| - quant_config_data = self.quantization_config.model_dump() |
673 |
| - config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data |
674 |
| - if self.sparsity_config is not None: |
675 |
| - sparsity_config_data = self.sparsity_config.model_dump() |
676 |
| - config_data[QUANTIZATION_CONFIG_NAME][ |
677 |
| - SPARSITY_CONFIG_NAME |
678 |
| - ] = sparsity_config_data |
| 692 | + # construct compression (quantization) config |
| 693 | + config_data[QUANTIZATION_CONFIG_NAME] = { |
| 694 | + COMPRESSION_VERSION_NAME: compressed_tensors.__version__, |
| 695 | + QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD, |
| 696 | + SPARSITY_CONFIG_NAME: sconfig_data, |
| 697 | + TRANSFORM_CONFIG_NAME: tconfig_data, |
| 698 | + **qconfig_data, |
| 699 | + } |
679 | 700 |
|
| 701 | + # write results to config.json file |
680 | 702 | with open(config_file_path, "w") as config_file:
|
681 | 703 | json.dump(config_data, config_file, indent=2, sort_keys=True)
|
682 | 704 |
|
|
0 commit comments