
Commit 45fbe5c

serialize
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a7b5272 commit 45fbe5c

File tree: 4 files changed, +68 -37 lines


src/compressed_tensors/base.py

Lines changed: 8 additions & 3 deletions

@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxiliary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
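
For reference, the constants above become the keys under which each section is serialized. A minimal sketch of reading them back out of a saved config.json (the path is a placeholder):

import json

from compressed_tensors.base import (
    QUANTIZATION_CONFIG_NAME,
    SPARSITY_CONFIG_NAME,
    TRANSFORM_CONFIG_NAME,
)

with open("path/to/model/config.json") as file:
    config_data = json.load(file)

# sparsity and transform configs are nested under the top-level quantization entry
compression_entry = config_data.get(QUANTIZATION_CONFIG_NAME, {})
sparsity_data = compression_entry.get(SPARSITY_CONFIG_NAME, {})
transform_data = compression_entry.get(TRANSFORM_CONFIG_NAME, {})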

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 55 additions & 33 deletions

@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ def from_compression_config(
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # NOTE: transform config is not supported by ctconfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,20 +182,27 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
+        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
            model, format=quantization_format
         )
 
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        if sparsity_config is None and quantization_config is None:
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
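
A quick sketch of the new from_pretrained_model behavior: a compressor is now returned even when only a transform config is present. The toy model and empty config below are illustrative assumptions; the attribute is normally attached by apply_transform_config (see the transform/apply.py change further down) rather than by hand.

import torch

from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.transform import TransformConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# normally attached by apply_transform_config; set directly here for illustration
transform_config = TransformConfig(config_groups={})
setattr(model, TRANSFORM_CONFIG_NAME, transform_config)

# previously this returned None for a dense, unquantized model;
# it now returns a compressor carrying the transform config
compressor = ModelCompressor.from_pretrained_model(model)
assert compressor is not None
assert compressor.transform_config is transform_config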
@@ -254,13 +266,17 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +656,49 @@ def update_config(self, save_directory: str):
 
         :param save_directory: path to a folder containing a HF model config
         """
-        if self.quantization_config is None and self.sparsity_config is None:
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained` or `from_compression_config`
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
+        # write to config.json file, regardless of whether it exists already;
+        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if not os.path.exists(config_file_path):
-            _LOGGER.warning(
-                f"Could not find a valid model config file in "
-                f"{save_directory}. Compression config will not be saved."
-            )
-            return
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
+        else:
+            config_data = {}
 
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
-        # required metadata whenever a quantization or sparsity config is present
-        # overwrite previous config and version if already existing
-        config_data[QUANTIZATION_CONFIG_NAME] = {}
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
-        if self.quantization_config is not None:
-            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
-        else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
 
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
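With this change the serialized entry always contains a version, a quant method, and sparsity/transform sub-configs (empty dicts when absent), with any quantization fields spread at the same level. A rough sketch of the shape that update_config writes (values are placeholders):

# approximate shape of config.json["quantization_config"] after update_config runs;
# values below are placeholders, and empty sub-configs serialize as {} rather than
# being omitted
expected_shape = {
    "version": "x.y.z",                    # COMPRESSION_VERSION_NAME
    "quant_method": "compressed-tensors",  # QUANTIZATION_METHOD_NAME / DEFAULT_QUANTIZATION_METHOD
    "sparsity_config": {},                 # SPARSITY_CONFIG_NAME, {} when no sparsity config
    "transform_config": {},                # TRANSFORM_CONFIG_NAME, {} when no transform config
    # ...plus the fields of quantization_config.model_dump(), if one is present
    # (quant_method and format are excluded from that dump)
}
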
src/compressed_tensors/transform/apply.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
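
A minimal usage sketch of the new attachment behavior; the toy model and the hadamard scheme below are assumptions for illustration:

import torch

from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import (
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
config = TransformConfig(config_groups={"u": TransformScheme(type="hadamard")})

apply_transform_config(model, config)

# the config now rides along on the model and is later picked up by
# ModelCompressor.from_pretrained_model / update_config during serialization
assert getattr(model, TRANSFORM_CONFIG_NAME) is config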

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
+from compressed_tensors.utils.internal import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
