|
29 | 29 | QUANTIZATION_CONFIG_NAME,
|
30 | 30 | QUANTIZATION_METHOD_NAME,
|
31 | 31 | SPARSITY_CONFIG_NAME,
|
32 |
| - TRANSFORM_CONFIG_NAME, |
33 | 32 | )
|
34 | 33 | from compressed_tensors.compressors.base import BaseCompressor
|
35 | 34 | from compressed_tensors.compressors.sparse_compressors import DenseCompressor
|
|
44 | 43 | )
|
45 | 44 | from compressed_tensors.quantization.lifecycle import expand_target_names
|
46 | 45 | from compressed_tensors.quantization.utils import is_module_quantized
|
47 |
| -from compressed_tensors.transform import TransformConfig |
48 | 46 | from compressed_tensors.utils import (
|
49 | 47 | align_module_device,
|
50 | 48 | delete_offload_parameter,
|
@@ -107,7 +105,6 @@ class ModelCompressor:
|
107 | 105 |
|
108 | 106 | sparsity_config: Optional[SparsityCompressionConfig] = None
|
109 | 107 | quantization_config: Optional[QuantizationConfig] = None
|
110 |
| - transform_config: Optional[TransformConfig] = None |
111 | 108 |
|
112 | 109 | @classmethod
|
113 | 110 | def from_pretrained(
|
@@ -147,8 +144,6 @@ def from_compression_config(
|
147 | 144 |
|
148 | 145 | sparsity_config = cls.parse_sparsity_config(compression_config)
|
149 | 146 | quantization_config = cls.parse_quantization_config(compression_config)
|
150 |
| - # TODO: transform config is not support by CompressedTensorsConfig yet |
151 |
| - |
152 | 147 | if sparsity_config is None and quantization_config is None:
|
153 | 148 | return None
|
154 | 149 |
|
@@ -182,27 +177,20 @@ def from_pretrained_model(
|
182 | 177 | algorithm
|
183 | 178 | :return: compressor for the configs, or None if model is not compressed
|
184 | 179 | """
|
185 |
| - # reconstruct config from schemes attached to modules |
186 | 180 | quantization_config = QuantizationConfig.from_pretrained(
|
187 | 181 | model, format=quantization_format
|
188 | 182 | )
|
189 | 183 |
|
190 |
| - # use config passed as argument |
191 | 184 | if isinstance(sparsity_config, str): # we passed in a sparsity format
|
192 | 185 | sparsity_config = SparsityCompressionConfig.load_from_registry(
|
193 | 186 | sparsity_config
|
194 | 187 | )
|
195 | 188 |
|
196 |
| - # use config attached to model |
197 |
| - transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None) |
198 |
| - |
199 |
| - if not any((quantization_config, sparsity_config, transform_config)): |
| 189 | + if sparsity_config is None and quantization_config is None: |
200 | 190 | return None
|
201 | 191 |
|
202 | 192 | return cls(
|
203 |
| - sparsity_config=sparsity_config, |
204 |
| - quantization_config=quantization_config, |
205 |
| - transform_config=transform_config, |
| 193 | + sparsity_config=sparsity_config, quantization_config=quantization_config |
206 | 194 | )
|
207 | 195 |
|
208 | 196 | @staticmethod
|
@@ -266,17 +254,13 @@ def __init__(
|
266 | 254 | self,
|
267 | 255 | sparsity_config: Optional[SparsityCompressionConfig] = None,
|
268 | 256 | quantization_config: Optional[QuantizationConfig] = None,
|
269 |
| - transform_config: Optional[TransformConfig] = None, |
270 | 257 | ):
|
271 | 258 | self.sparsity_config = sparsity_config
|
272 | 259 | self.quantization_config = quantization_config
|
273 |
| - self.transform_config = transform_config |
274 |
| - |
275 | 260 | self.sparsity_compressor = None
|
276 | 261 | self.quantization_compressor: Optional[
|
277 | 262 | Union[BaseQuantizationCompressor, DenseCompressor]
|
278 | 263 | ] = None
|
279 |
| - # no transform compressor is required |
280 | 264 |
|
281 | 265 | if sparsity_config is not None:
|
282 | 266 | self.sparsity_compressor = BaseCompressor.load_from_registry(
|
@@ -656,49 +640,43 @@ def update_config(self, save_directory: str):
|
656 | 640 |
|
657 | 641 | :param save_directory: path to a folder containing a HF model config
|
658 | 642 | """
|
659 |
| - # this check is also done in `from_pretrained_model`, |
660 |
| - # but not in `from_pretrained`` or `from_compression_config`` |
661 |
| - if not any( |
662 |
| - (self.quantization_config, self.sparsity_config, self.transform_config) |
663 |
| - ): |
| 643 | + if self.quantization_config is None and self.sparsity_config is None: |
664 | 644 | return
|
665 | 645 |
|
666 |
| - # write to config.json file, regardless of whether it exists already |
667 |
| - # overwrite previous config and version if already existing |
668 | 646 | config_file_path = os.path.join(save_directory, CONFIG_NAME)
|
669 |
| - if os.path.exists(config_file_path): |
670 |
| - with open(config_file_path, "r") as file: |
671 |
| - config_data = json.load(file) |
672 |
| - else: |
673 |
| - config_data = {} |
| 647 | + if not os.path.exists(config_file_path): |
| 648 | + _LOGGER.warning( |
| 649 | + f"Could not find a valid model config file in " |
| 650 | + f"{save_directory}. Compression config will not be saved." |
| 651 | + ) |
| 652 | + return |
674 | 653 |
|
675 |
| - # serialize configs into json |
676 |
| - qconfig_data = ( |
677 |
| - self.quantization_config.model_dump(exclude=["quant_method", "format"]) |
678 |
| - if self.quantization_config is not None |
679 |
| - else {} |
680 |
| - ) |
681 |
| - sconfig_data = ( |
682 |
| - self.sparsity_config.model_dump() |
683 |
| - if self.sparsity_config is not None |
684 |
| - else {} |
685 |
| - ) |
686 |
| - tconfig_data = ( |
687 |
| - self.transform_config.model_dump() |
688 |
| - if self.transform_config is not None |
689 |
| - else {} |
690 |
| - ) |
| 654 | + with open(config_file_path, "r") as config_file: |
| 655 | + config_data = json.load(config_file) |
691 | 656 |
|
692 |
| - # construct compression (quantization) config |
693 |
| - config_data[QUANTIZATION_CONFIG_NAME] = { |
694 |
| - COMPRESSION_VERSION_NAME: compressed_tensors.__version__, |
695 |
| - QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD, |
696 |
| - SPARSITY_CONFIG_NAME: sconfig_data, |
697 |
| - TRANSFORM_CONFIG_NAME: tconfig_data, |
698 |
| - **qconfig_data, |
699 |
| - } |
| 657 | + # required metadata whenever a quantization or sparsity config is present |
| 658 | + # overwrite previous config and version if already existing |
| 659 | + config_data[QUANTIZATION_CONFIG_NAME] = {} |
| 660 | + config_data[QUANTIZATION_CONFIG_NAME][ |
| 661 | + COMPRESSION_VERSION_NAME |
| 662 | + ] = compressed_tensors.__version__ |
| 663 | + if self.quantization_config is not None: |
| 664 | + self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD |
| 665 | + else: |
| 666 | + config_data[QUANTIZATION_CONFIG_NAME][ |
| 667 | + QUANTIZATION_METHOD_NAME |
| 668 | + ] = DEFAULT_QUANTIZATION_METHOD |
| 669 | + |
| 670 | + # quantization and sparsity configs |
| 671 | + if self.quantization_config is not None: |
| 672 | + quant_config_data = self.quantization_config.model_dump() |
| 673 | + config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data |
| 674 | + if self.sparsity_config is not None: |
| 675 | + sparsity_config_data = self.sparsity_config.model_dump() |
| 676 | + config_data[QUANTIZATION_CONFIG_NAME][ |
| 677 | + SPARSITY_CONFIG_NAME |
| 678 | + ] = sparsity_config_data |
700 | 679 |
|
701 |
| - # write results to config.json file |
702 | 680 | with open(config_file_path, "w") as config_file:
|
703 | 681 | json.dump(config_data, config_file, indent=2, sort_keys=True)
|
704 | 682 |
|
|
0 commit comments