
Commit 667310f

Revert "[Transform] Serialize transforms config (#412)"
This reverts commit 0731aa5.
1 parent de945c6 · commit 667310f

File tree

4 files changed: +38 -69 lines

src/compressed_tensors/base.py
src/compressed_tensors/compressors/model_compressors/model_compressor.py
src/compressed_tensors/transform/apply.py
src/compressed_tensors/transform/factory/base.py


src/compressed_tensors/base.py

Lines changed: 3 additions & 8 deletions
@@ -12,14 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# configs
-QUANTIZATION_CONFIG_NAME = "quantization_config"
 SPARSITY_CONFIG_NAME = "sparsity_config"
-TRANSFORM_CONFIG_NAME = "transform_config"
-
-# required fields
+QUANTIZATION_CONFIG_NAME = "quantization_config"
+COMPRESSION_CONFIG_NAME = "compression_config"
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
-
-# auxillary configs
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
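Note: these module-level constants in compressed_tensors.base are the keys under which compression metadata is written into a model's config.json. A minimal sketch of reading them back after this revert; the directory path is a placeholder, not from this commit:

import json
import os

from compressed_tensors.base import QUANTIZATION_CONFIG_NAME, SPARSITY_CONFIG_NAME

# "model_dir" stands in for a saved Hugging Face model directory
with open(os.path.join("model_dir", "config.json")) as config_file:
    config = json.load(config_file)

# after this revert, sparsity metadata is nested under the quantization config
# key, and no transform_config key is written at all
quant_section = config.get(QUANTIZATION_CONFIG_NAME, {})
sparsity_section = quant_section.get(SPARSITY_CONFIG_NAME)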

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 33 additions & 55 deletions
@@ -29,7 +29,6 @@
2929
QUANTIZATION_CONFIG_NAME,
3030
QUANTIZATION_METHOD_NAME,
3131
SPARSITY_CONFIG_NAME,
32-
TRANSFORM_CONFIG_NAME,
3332
)
3433
from compressed_tensors.compressors.base import BaseCompressor
3534
from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -44,7 +43,6 @@
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
-from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -107,7 +105,6 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
-    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -147,8 +144,6 @@ def from_compression_config(
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
-        # TODO: transform config is not support by CompressedTensorsConfig yet
-
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -182,27 +177,20 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
 
-        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        # use config attached to model
-        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
-
-        if not any((quantization_config, sparsity_config, transform_config)):
+        if sparsity_config is None and quantization_config is None:
             return None
 
         return cls(
-            sparsity_config=sparsity_config,
-            quantization_config=quantization_config,
-            transform_config=transform_config,
+            sparsity_config=sparsity_config, quantization_config=quantization_config
         )
 
     @staticmethod
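For orientation, a minimal sketch of calling the post-revert entry point; the toy model is an illustrative stand-in, not part of this commit:

import torch

from compressed_tensors.compressors import ModelCompressor

# a real call site passes a model whose submodules already carry quantization
# schemes; this toy module has none
model = torch.nn.Linear(8, 8)

# returns None here, since neither a quantization nor a sparsity config is found
compressor = ModelCompressor.from_pretrained_model(
    model,
    sparsity_config=None,
    quantization_format=None,
)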
@@ -266,17 +254,13 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
-        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
-        self.transform_config = transform_config
-
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
-        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -656,49 +640,43 @@ def update_config(self, save_directory: str):
 
         :param save_directory: path to a folder containing a HF model config
         """
-        # this check is also done in `from_pretrained_model`,
-        # but not in `from_pretrained`` or `from_compression_config``
-        if not any(
-            (self.quantization_config, self.sparsity_config, self.transform_config)
-        ):
+        if self.quantization_config is None and self.sparsity_config is None:
             return
 
-        # write to config.json file, regardless of whether it exists already
-        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if os.path.exists(config_file_path):
-            with open(config_file_path, "r") as file:
-                config_data = json.load(file)
-        else:
-            config_data = {}
+        if not os.path.exists(config_file_path):
+            _LOGGER.warning(
+                f"Could not find a valid model config file in "
+                f"{save_directory}. Compression config will not be saved."
+            )
+            return
 
-        # serialize configs into json
-        qconfig_data = (
-            self.quantization_config.model_dump(exclude=["quant_method", "format"])
-            if self.quantization_config is not None
-            else {}
-        )
-        sconfig_data = (
-            self.sparsity_config.model_dump()
-            if self.sparsity_config is not None
-            else {}
-        )
-        tconfig_data = (
-            self.transform_config.model_dump()
-            if self.transform_config is not None
-            else {}
-        )
+        with open(config_file_path, "r") as config_file:
+            config_data = json.load(config_file)
 
-        # construct compression (quantization) config
-        config_data[QUANTIZATION_CONFIG_NAME] = {
-            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
-            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
-            SPARSITY_CONFIG_NAME: sconfig_data,
-            TRANSFORM_CONFIG_NAME: tconfig_data,
-            **qconfig_data,
-        }
+        # required metadata whenever a quantization or sparsity config is present
+        # overwrite previous config and version if already existing
+        config_data[QUANTIZATION_CONFIG_NAME] = {}
+        config_data[QUANTIZATION_CONFIG_NAME][
+            COMPRESSION_VERSION_NAME
+        ] = compressed_tensors.__version__
+        if self.quantization_config is not None:
+            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
+        else:
+            config_data[QUANTIZATION_CONFIG_NAME][
+                QUANTIZATION_METHOD_NAME
+            ] = DEFAULT_QUANTIZATION_METHOD
+
+        # quantization and sparsity configs
+        if self.quantization_config is not None:
+            quant_config_data = self.quantization_config.model_dump()
+            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
+        if self.sparsity_config is not None:
+            sparsity_config_data = self.sparsity_config.model_dump()
+            config_data[QUANTIZATION_CONFIG_NAME][
+                SPARSITY_CONFIG_NAME
+            ] = sparsity_config_data
 
-        # write results to config.json file
         with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)
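To make the restored update_config behavior concrete, here is a hedged sketch of the "quantization_config" section it writes when only a sparsity config is present; the version string and sparsity format are placeholders:

# illustrative shape of config.json's "quantization_config" entry after this
# revert, for a compressor with a sparsity config but no quantization config
example_section = {
    "version": "0.x.y",                    # compressed_tensors.__version__
    "quant_method": "compressed-tensors",  # DEFAULT_QUANTIZATION_METHOD
    "sparsity_config": {
        "format": "sparse-bitmask",        # placeholder sparsity format
    },
}

When a quantization config is present, the restored code instead replaces the whole section with quantization_config.model_dump() and then nests the sparsity config inside it, so the version key written above is overwritten in that path.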

src/compressed_tensors/transform/apply.py

Lines changed: 0 additions & 4 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import torch
-from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -31,6 +30,3 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
-
-    # attach config to model for compression/serialization
-    setattr(model, TRANSFORM_CONFIG_NAME, config)
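For reference, a minimal usage sketch of the function whose tail was removed here; the empty config is an illustrative assumption, just enough to show the call shape:

import torch

from compressed_tensors.transform import TransformConfig
from compressed_tensors.transform.apply import apply_transform_config

# hypothetical, empty transform config; real configs populate config_groups
# with named TransformScheme entries
config = TransformConfig(config_groups={})

model = torch.nn.Linear(8, 8)
apply_transform_config(model, config)

# after this revert, the config is no longer attached to the model, so
# ModelCompressor will not pick it up for serialization into config.json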

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 2 deletions
@@ -14,10 +14,11 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Tuple, Set
 
 import torch
 import torch.nn.utils.parametrize as P
+from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -33,7 +34,6 @@
     register_offload_module,
     update_offload_parameter,
 )
-from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
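Apart from a reordering of the typing imports, the only change in this file is where InternalModule is imported from; after the revert it comes from the package root rather than compressed_tensors.utils.internal:

# post-revert import path, as shown in the hunk above
from compressed_tensors import InternalModule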
