Skip to content

Commit 8f514c4

Browse files
committed
update
1 parent c49a94b commit 8f514c4

File tree

2 files changed: +7 additions, -7 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def from_pretrained_model(
180180
# assume multiple compression formats means mixed-precision
181181
# as we currently only support one compressor per precision type and scheme
182182
if len(quantization_format) > 1:
183-
quantization_format = CompressionFormat.mixed_precision
183+
quantization_format = CompressionFormat.mixed_precision.value
184184
else:
185185
quantization_format = quantization_format[0]
186186

@@ -258,15 +258,15 @@ def parse_quantization_config(
258258

259259
return quantization_config
260260

261-
def _fetch_unique_quantization_formats(self):
261+
def _fetch_unique_quantization_formats(self) -> List[str]:
262262
"""
263-
Get all unique compression formats used in
264-
model
263+
Get all unique compression formats present in a model
264+
:return: list of quantization formats
265265
"""
266266
quantization_formats = []
267267
for _, scheme in self.quantization_config.config_groups.items():
268268
if scheme.format not in quantization_formats:
269-
quantization_formats.append(scheme)
269+
quantization_formats.append(scheme.format)
270270
return quantization_formats
271271

272272
def __init__(
@@ -287,7 +287,6 @@ def __init__(
287287
)
288288

289289
quantization_formats = self._fetch_unique_quantization_formats()
290-
291290
if quantization_config is not None:
292291
self.quantization_compressor = {}
293292
for format in quantization_formats:
@@ -694,6 +693,7 @@ def update_config(self, save_directory: str):
694693
config_data[QUANTIZATION_CONFIG_NAME][
695694
COMPRESSION_VERSION_NAME
696695
] = compressed_tensors.__version__
696+
697697
if self.quantization_config is not None:
698698
self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
699699
else:

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class QuantizationScheme(BaseModel):
4949
weights: Optional[QuantizationArgs] = None
5050
input_activations: Optional[QuantizationArgs] = None
5151
output_activations: Optional[QuantizationArgs] = None
52-
format: Optional[CompressionFormat] = None
52+
format: Optional[str] = None
5353

5454
@model_validator(mode="after")
5555
def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":

Comments: 0 commit comments