
Commit b5cd4e7: add tests
Parent: 8b5d4c9

File tree: 3 files changed, +52 -4 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 8 additions & 2 deletions
@@ -189,6 +189,7 @@ def from_pretrained_model(
         if isinstance(quantization_format, (str, CompressionFormat)):
             quantization_format = [quantization_format]
 
+        compression_formats = quantization_format
         # assume multiple compression formats means mixed-precision
         # as we currently only support one compressor per precision type and scheme
         if len(quantization_format) > 1:
@@ -216,6 +217,7 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=compression_formats,
         )
 
     @staticmethod
@@ -296,10 +298,12 @@ def __init__(
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats
 
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
@@ -313,9 +317,11 @@ def __init__(
         )
 
         if quantization_config is not None:
-            quantization_formats = self._fetch_unique_quantization_formats()
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
             self.quantization_compressor = {}
-            for format in quantization_formats:
+            for format in self.compression_formats:
                 self.quantization_compressor[
                     format
                 ] = BaseCompressor.load_from_registry(
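
Taken together, these hunks thread an explicit format list from from_pretrained_model into __init__, which now only falls back to scanning the model for unique formats when no list is supplied. A minimal usage sketch of the new path, assuming the signature exactly as diffed above (the model variable stands in for any module whose layers carry quantization schemes; see the new test at the bottom of this commit for a runnable version):

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat

# Two formats imply mixed precision: __init__ builds one compressor per
# format, keyed by the format string, rather than a single compressor.
compressor = ModelCompressor.from_pretrained_model(
    model,
    sparsity_config=None,
    quantization_format=[
        CompressionFormat.float_quantized.value,
        CompressionFormat.nvfp4_pack_quantized.value,
    ],
)
assert isinstance(compressor.quantization_compressor, dict)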

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat for the layer
     """
 
     targets: List[str]
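
The docstring line documents the new format field on QuantizationScheme, which lets an individual layer pin its own compression format. A small construction sketch mirroring the test added below (defaults for fields not shown are assumed):

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# An fp8 weight-and-activation scheme that pins the float-quantized format.
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="float"),
    input_activations=QuantizationArgs(num_bits=8, type="float"),
    format=CompressionFormat.float_quantized.value,
)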

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 43 additions & 2 deletions
@@ -20,8 +20,12 @@
 import torch
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+)
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -469,6 +473,43 @@ def test_compress_model_meta(model_stub, q_format, s_config):
     assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
 
 
+def test_multiple_quant_compressors():
+    model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
+    input_activations = QuantizationArgs(num_bits=8, type="float")
+    weights = QuantizationArgs(num_bits=8, type="float")
+
+    scheme_fp8 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.float_quantized.value,
+    )
+
+    input_activations = QuantizationArgs(num_bits=4, type="float")
+    weights = QuantizationArgs(num_bits=4, type="float")
+
+    scheme_nvfp4 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.nvfp4_pack_quantized.value,
+    )
+
+    model[0].quantization_scheme = scheme_fp8
+    model[0].quantization_status = "frozen"
+    model[1].quantization_scheme = scheme_nvfp4
+    model[1].quantization_status = "frozen"
+
+    formats = [scheme_fp8.format, scheme_nvfp4.format]
+
+    compressor = ModelCompressor.from_pretrained_model(model, None, formats)
+    assert isinstance(compressor.quantization_compressor, dict)
+    assert (
+        compressor.quantization_config.format == CompressionFormat.mixed_precision.value
+    )
+    assert all(format in compressor.quantization_compressor for format in formats)
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [
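
For contrast, a hedged sketch of the single-format path implied by the len(quantization_format) > 1 branch in the first file: with only one format there is no promotion to mixed precision, though the compressor mapping is still a dict. This continues from the imports in the sketches above and assumes a model whose layers all share one fp8 scheme:

compressor = ModelCompressor.from_pretrained_model(
    model, None, [CompressionFormat.float_quantized.value]
)
# A single entry keyed by the lone format; the config keeps that format
# rather than being promoted to mixed-precision.
assert list(compressor.quantization_compressor) == [
    CompressionFormat.float_quantized.value
]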
