
Commit 0731aa5

Authored by kylesayrs, brian-dellabetta, and dsikka

[Transform] Serialize transforms config (#412)
* add utilities
* add tests
* add additional tests
* add utils and tests
* Implement transform factories
* add permutations
* add delete_offload_module
* key inverses by weight
* fix tests
* standardize random hadamard
* prepend input hooks
* apply sqrt division first
* use divided hadamards
* fix typo
* add random option
* use random seeds, rename matrix multiply
* add deterministic generation to random matrix
* fix perm math
* update docstrings
* update docstrings
* cleanup
* cleanup 2
* make seed optional
* remove iterable check and missing return value
* Remove unrelated changes
* simplify code
* implement apply, use in tests
* use hadamards database file
* try manifest
* try setup, update hadamards list
* fix setup
* add docstrings, cleanup
* fix setup, thank you @dbarbuzzi
* remove numpy, add tests
* solidify dtype, add gpu tests
* fix docstring
* add device option
* construct on execution device, cache on offload device
* save construction device changes for later
* construct on execution device, cache on offload device
* cite nja sloane
* remove dreg
* put on device via safe_open
* nits and docstrings
* update docstring
* Merge
* merge with construct: construct in float32
* construct with same dtype, constructing on fp32 found no difference
* remove unnecessary imports
* bugfixes (#375)
* use factory_kwargs
* add frozen dict to deps
* fix style
* merge
* use delete_offload_module
* add docstring
* use parametrize
* populate _dynamic_tied_weights_keys
* ensure serializable
* remove extra space
* apply style
* merge dregs
* skip offloading tests until transformers changes land
* use set
* [Quantization][Decompression] Fix QDQ for dynamic quant; Update NVFP4 Compression Params (#407)
  * add compression param; update qdq for batch greater than 1
  * make generic
  * fix tests
  * remove incorrect line change; make generic
  * update
* serialize
* fix typo, comment

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 131673e commit 0731aa5

File tree

4 files changed: +69 −38 lines

* src/compressed_tensors/base.py (+8, −3)
* src/compressed_tensors/compressors/model_compressors/model_compressor.py (+55, −33)
* src/compressed_tensors/transform/apply.py (+4, −0)
* src/compressed_tensors/transform/factory/base.py (+2, −2)
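Taken together, the four diffs below wire a transform config through save-time serialization: apply_transform_config attaches the config to the model, ModelCompressor.from_pretrained_model picks it up, and update_config writes it into config.json. The sketch below illustrates that round trip under stated assumptions: the toy model, the Hadamard scheme fields, and the output directory are illustrative, not taken from this commit.

    import os
    import torch
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.transform import (
        TransformArgs,
        TransformConfig,
        TransformScheme,
        apply_transform_config,
    )

    # stand-in model; any module tree with Linear layers behaves the same
    model = torch.nn.Sequential(torch.nn.Linear(64, 64))

    # hypothetical scheme: fuse a Hadamard rotation into Linear weight inputs
    scheme = TransformScheme(
        type="hadamard",
        apply=[TransformArgs(targets=["Linear"], location="weight_input")],
    )
    config = TransformConfig(config_groups={"v": scheme})

    # as of this commit, this also attaches the config to the model
    apply_transform_config(model, config)

    # from_pretrained_model now recovers the attached config via getattr
    compressor = ModelCompressor.from_pretrained_model(model)

    # update_config serializes it under "transform_config" in config.json
    os.makedirs("./out", exist_ok=True)
    compressor.update_config("./out")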

src/compressed_tensors/base.py

Lines changed: 8 additions & 3 deletions
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxillary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
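These constants are the keys under which ModelCompressor.update_config (changed below) serializes each config into config.json. A rough sketch of the resulting layout; the "compressed-tensors" method string is an assumption about DEFAULT_QUANTIZATION_METHOD, and the empty dicts stand in for real config dumps:

    import compressed_tensors

    # sketch of the block update_config writes under QUANTIZATION_CONFIG_NAME
    config_data = {
        "quantization_config": {                        # QUANTIZATION_CONFIG_NAME
            "version": compressed_tensors.__version__,  # COMPRESSION_VERSION_NAME
            "quant_method": "compressed-tensors",       # QUANTIZATION_METHOD_NAME (assumed default)
            "sparsity_config": {},                      # SPARSITY_CONFIG_NAME
            "transform_config": {},                     # TRANSFORM_CONFIG_NAME
            # the quantization config's own fields are merged in last (**qconfig_data)
        }
    }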

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 55 additions & 33 deletions
@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ def from_compression_config(
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # TODO: transform config is not support by CompressedTensorsConfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,20 +182,27 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
+        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
 
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        if sparsity_config is None and quantization_config is None:
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
@@ -254,13 +266,17 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +656,49 @@ def update_config(self, save_directory: str):
 
         :param save_directory: path to a folder containing a HF model config
         """
-        if self.quantization_config is None and self.sparsity_config is None:
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained`` or `from_compression_config``
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
+        # write to config.json file, regardless of whether it exists already
+        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if not os.path.exists(config_file_path):
-            _LOGGER.warning(
-                f"Could not find a valid model config file in "
-                f"{save_directory}. Compression config will not be saved."
-            )
-            return
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
+        else:
+            config_data = {}
 
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
-        # required metadata whenever a quantization or sparsity config is present
-        # overwrite previous config and version if already existing
-        config_data[QUANTIZATION_CONFIG_NAME] = {}
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
-        if self.quantization_config is not None:
-            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
-        else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
 
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)
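The most visible behavior change in update_config: a missing config.json is no longer a warning-and-return, the file is now created from an empty dict. A sketch of that path; instantiating QuantizationConfig with empty config_groups is an assumption made purely to give the compressor something to serialize:

    import json
    import os
    import tempfile

    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.quantization import QuantizationConfig

    # minimal (assumed-valid) config so the compressor has content to write
    compressor = ModelCompressor(quantization_config=QuantizationConfig(config_groups={}))

    with tempfile.TemporaryDirectory() as save_dir:
        # save_dir contains no config.json; previously this warned and returned
        compressor.update_config(save_dir)
        with open(os.path.join(save_dir, "config.json")) as f:
            print(sorted(json.load(f)["quantization_config"]))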

src/compressed_tensors/transform/apply.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
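Because the config is attached as a plain model attribute, downstream code can recover it by name. A minimal sketch, assuming TRANSFORM_CONFIG_NAME resolves to "transform_config" as defined in base.py above; the empty TransformConfig is just enough for the attach:

    import torch
    from compressed_tensors import TRANSFORM_CONFIG_NAME
    from compressed_tensors.transform import TransformConfig, apply_transform_config

    model = torch.nn.Linear(8, 8)                # toy module
    config = TransformConfig(config_groups={})   # empty config, enough for the attach
    apply_transform_config(model, config)

    # mirrors what ModelCompressor.from_pretrained_model does internally
    attached = getattr(model, TRANSFORM_CONFIG_NAME, None)
    assert attached is config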

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,10 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -34,6 +33,7 @@
     register_offload_module,
     update_offload_parameter,
 )
+from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
 

(Moving the InternalModule import off the package root and onto compressed_tensors.utils.internal presumably avoids a circular import, now that the transform modules themselves import TRANSFORM_CONFIG_NAME from compressed_tensors directly.)