diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py index 0e073262..f0aaa420 100644 --- a/src/compressed_tensors/base.py +++ b/src/compressed_tensors/base.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -SPARSITY_CONFIG_NAME = "sparsity_config" +# configs QUANTIZATION_CONFIG_NAME = "quantization_config" -COMPRESSION_CONFIG_NAME = "compression_config" -KV_CACHE_SCHEME_NAME = "kv_cache_scheme" +SPARSITY_CONFIG_NAME = "sparsity_config" +TRANSFORM_CONFIG_NAME = "transform_config" + +# required fields COMPRESSION_VERSION_NAME = "version" QUANTIZATION_METHOD_NAME = "quant_method" + +# auxiliary configs +KV_CACHE_SCHEME_NAME = "kv_cache_scheme" diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index f567e5a6..24dcfe95 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -29,6 +29,7 @@ QUANTIZATION_CONFIG_NAME, QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, + TRANSFORM_CONFIG_NAME, ) from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors import DenseCompressor @@ -43,6 +44,7 @@ ) from compressed_tensors.quantization.lifecycle import expand_target_names from compressed_tensors.quantization.utils import is_module_quantized +from compressed_tensors.transform import TransformConfig from compressed_tensors.utils import ( align_module_device, delete_offload_parameter, @@ -105,6 +107,7 @@ class ModelCompressor: sparsity_config: Optional[SparsityCompressionConfig] = None quantization_config: Optional[QuantizationConfig] = None + transform_config: Optional[TransformConfig] = None @classmethod def from_pretrained( @@ -144,6 +147,8 @@ def from_compression_config( sparsity_config = 
cls.parse_sparsity_config(compression_config) quantization_config = cls.parse_quantization_config(compression_config) + # TODO: transform config is not supported by CompressedTensorsConfig yet + if sparsity_config is None and quantization_config is None: return None @@ -177,20 +182,27 @@ def from_pretrained_model( algorithm :return: compressor for the configs, or None if model is not compressed """ + # reconstruct config from schemes attached to modules quantization_config = QuantizationConfig.from_pretrained( model, format=quantization_format ) + # use config passed as argument if isinstance(sparsity_config, str): # we passed in a sparsity format sparsity_config = SparsityCompressionConfig.load_from_registry( sparsity_config ) - if sparsity_config is None and quantization_config is None: + # use config attached to model + transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None) + + if not any((quantization_config, sparsity_config, transform_config)): return None return cls( - sparsity_config=sparsity_config, quantization_config=quantization_config + sparsity_config=sparsity_config, + quantization_config=quantization_config, + transform_config=transform_config, ) @staticmethod @@ -254,13 +266,17 @@ def __init__( self, sparsity_config: Optional[SparsityCompressionConfig] = None, quantization_config: Optional[QuantizationConfig] = None, + transform_config: Optional[TransformConfig] = None, ): self.sparsity_config = sparsity_config self.quantization_config = quantization_config + self.transform_config = transform_config + self.sparsity_compressor = None self.quantization_compressor: Optional[ Union[BaseQuantizationCompressor, DenseCompressor] ] = None + # no transform compressor is required if sparsity_config is not None: self.sparsity_compressor = BaseCompressor.load_from_registry( @@ -640,43 +656,49 @@ def update_config(self, save_directory: str): :param save_directory: path to a folder containing a HF model config """ - if self.quantization_config is None and 
self.sparsity_config is None: + # this check is also done in `from_pretrained_model`, + # but not in `from_pretrained` or `from_compression_config` + if not any( + (self.quantization_config, self.sparsity_config, self.transform_config) + ): return + # write to config.json file, regardless of whether it exists already + # overwrite previous config and version if already existing config_file_path = os.path.join(save_directory, CONFIG_NAME) - if not os.path.exists(config_file_path): - _LOGGER.warning( - f"Could not find a valid model config file in " - f"{save_directory}. Compression config will not be saved." - ) - return + if os.path.exists(config_file_path): + with open(config_file_path, "r") as file: + config_data = json.load(file) + else: + config_data = {} - with open(config_file_path, "r") as config_file: - config_data = json.load(config_file) + # serialize configs into json + qconfig_data = ( + self.quantization_config.model_dump(exclude=["quant_method", "format"]) + if self.quantization_config is not None + else {} + ) + sconfig_data = ( + self.sparsity_config.model_dump() + if self.sparsity_config is not None + else {} + ) + tconfig_data = ( + self.transform_config.model_dump() + if self.transform_config is not None + else {} + ) - # required metadata whenever a quantization or sparsity config is present - # overwrite previous config and version if already existing - config_data[QUANTIZATION_CONFIG_NAME] = {} - config_data[QUANTIZATION_CONFIG_NAME][ - COMPRESSION_VERSION_NAME - ] = compressed_tensors.__version__ - if self.quantization_config is not None: - self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD - else: - config_data[QUANTIZATION_CONFIG_NAME][ - QUANTIZATION_METHOD_NAME - ] = DEFAULT_QUANTIZATION_METHOD - - # quantization and sparsity configs - if self.quantization_config is not None: - quant_config_data = self.quantization_config.model_dump() - config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data - if self.sparsity_config 
is not None: - sparsity_config_data = self.sparsity_config.model_dump() - config_data[QUANTIZATION_CONFIG_NAME][ - SPARSITY_CONFIG_NAME - ] = sparsity_config_data + # construct compression (quantization) config + config_data[QUANTIZATION_CONFIG_NAME] = { + COMPRESSION_VERSION_NAME: compressed_tensors.__version__, + QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD, + SPARSITY_CONFIG_NAME: sconfig_data, + TRANSFORM_CONFIG_NAME: tconfig_data, + **qconfig_data, + } + # write results to config.json file with open(config_file_path, "w") as config_file: json.dump(config_data, config_file, indent=2, sort_keys=True) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index a5d4c8c2..e247e702 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): for name, scheme in config.config_groups.items(): factory = TransformFactory.from_scheme(scheme, name=name) factory.apply_to_model(model) + + # attach config to model for compression/serialization + setattr(model, TRANSFORM_CONFIG_NAME, config) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 2218bd30..a7744709 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -14,11 +14,10 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import List, Optional, Tuple, Set +from typing import List, Optional, Set, Tuple import torch import torch.nn.utils.parametrize as P -from compressed_tensors import InternalModule from compressed_tensors.registry.registry import RegistryMixin, T from 
compressed_tensors.transform import ( TransformArgs, @@ -34,6 +33,7 @@ register_offload_module, update_offload_parameter, ) +from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.nn import Module, Parameter