
Commit c688c79

[Model Compressor] Move infer call to from_pretrained_model method (#470)
* move infer call to model
* global format support
* fix
* more clean-up
* fix type hint
* update
* docstring
* sqap
* update
* update
* update
1 parent dfd069b commit c688c79

File tree

4 files changed (+70 -29 lines):
  src/compressed_tensors/compressors/model_compressors/model_compressor.py
  src/compressed_tensors/config/format.py
  src/compressed_tensors/quantization/quant_config.py
  tests/test_compressors/model_compressors/test_model_compressor.py


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 37 additions & 15 deletions
@@ -33,6 +33,9 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.config.format import (
+    infer_and_set_per_module_quantization_format,
+)
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
@@ -58,6 +61,7 @@
     is_compressed_tensors_config,
 )
 from compressed_tensors.utils.match import match_named_modules
+from loguru import logger
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -166,29 +170,50 @@ def from_compression_config(
     def from_pretrained_model(
         cls,
         model: Module,
+        sparsity_config_or_format: Union[SparsityCompressionConfig, str, None] = None,
+        quantization_format: Optional[str] = None,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
         load the appropriate compressors
 
         :param model: pytorch model to target for compression
         :param sparsity_config: a filled in sparsity config or string corresponding
-            to a sparsity compression algorithm
-        :param quantization_format: string corresponding to a quantization compression
-            algorithm
+            to a sparsity format
+        :param quantization_format: string corresponding to a quantization
+            format that should be applied to the entire model
         :return: compressor for the configs, or None if model is not compressed
         """
-        quantization_config = QuantizationConfig.from_pretrained(
-            model, format=quantization_format
-        )
+        if sparsity_config:
+            logger.warning(
+                "sparsity_config is deprecated, use sparsity_config_or_format"
+            )
+            sparsity_config_or_format = sparsity_config
 
-        # use config passed as argument
-        if isinstance(sparsity_config, str):  # we passed in a sparsity format
+        if sparsity_config_or_format and isinstance(
+            sparsity_config_or_format, str
+        ):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
-                sparsity_config
+                sparsity_config_or_format
             )
+        else:
+            # otherwise, config or None
+            sparsity_config = sparsity_config_or_format
+
+        quantization_format = infer_and_set_per_module_quantization_format(
+            model=model,
+            sparsity_structure=(
+                sparsity_config.sparsity_structure
+                if sparsity_config is not None
+                else None
+            ),
+            quantization_format=quantization_format,
+        )
+
+        quantization_config = QuantizationConfig.from_pretrained(
+            model, format=quantization_format
+        )
 
         # use config attached to model
         transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
@@ -200,11 +225,7 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
-            compression_formats=(
-                [quantization_format]
-                if isinstance(quantization_format, str)
-                else quantization_format
-            ),
+            compression_formats=quantization_format,
         )
 
     @staticmethod
@@ -620,6 +641,7 @@ def decompress(self, model_path: str, model: Module):
             # compressor in a follow-up including initialization
             load_weight_qparams=load_weight_qparams,
         )
+
         model_path_or_state_dict = (
             model.state_dict() if sparse_decompressed else model_path
         )
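With this change, callers no longer compute the quantization format themselves: from_pretrained_model infers the per-module formats internally. A minimal usage sketch (hedged; the model path and the "sparse-24-bitmask" format name are illustrative assumptions, not taken from this diff):

    from transformers import AutoModelForCausalLM
    from compressed_tensors.compressors.model_compressors.model_compressor import (
        ModelCompressor,
    )

    # placeholder path; any quantized model works in principle
    model = AutoModelForCausalLM.from_pretrained("path/to/quantized-model")

    # New-style call: pass a sparsity format string (or a SparsityCompressionConfig)
    # via sparsity_config_or_format; leave quantization_format unset so the
    # per-module inference decides the compression format(s).
    compressor = ModelCompressor.from_pretrained_model(
        model,
        sparsity_config_or_format="sparse-24-bitmask",  # assumed format name
    )

    # The old keyword still works but now logs a deprecation warning.
    compressor = ModelCompressor.from_pretrained_model(
        model, sparsity_config="sparse-24-bitmask"
    )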

src/compressed_tensors/config/format.py

Lines changed: 28 additions & 9 deletions
@@ -82,14 +82,18 @@ def _get_quant_compression_format(
 
 
 def set_per_module_format(
-    module: torch.nn.Module, sparsity_structure: Optional[str] = None
+    module: torch.nn.Module,
+    sparsity_structure: Optional[str] = None,
+    quantization_format: Optional[str] = None,
 ):
     """
     Determine and set the per module quantization format given quantization args
     and sparsity structure.
 
     :param module: module which has its quantization inferred
     :param sparsity_structure: optional sparsity applied to the module
+    :param quantization_format: optional global format to override
+        the per module formats
 
     """
     weight_scheme = module.quantization_scheme.weights
@@ -100,41 +104,56 @@
         input_scheme, weight_scheme, sparsity_structure
     )
 
-    # If set, we check if it matches our inferred one
-    if module.quantization_scheme.format is not None:
+    # Check if a global format was provided first
+    # This will override any per module format
+    if quantization_format is not None:
+        if quantization_format != compression_format.value:
+            logger.warning(
+                "The provided format for the module does not match the "
+                "inferred format. Compression may fail "
+            )
+        module.quantization_scheme.format = quantization_format
+    # If a per module format is not provided, we check if it matches our inferred one
+    elif module.quantization_scheme.format is not None:
         # If it does not, warn the user
         if module.quantization_scheme.format != compression_format.value:
             logger.warning(
                 "The provided format for the module does not match the "
                 "inferred format. Compression may fail "
             )
+    # If neither provided, set ours
     else:
-        # If not set, we set ours
         module.quantization_scheme.format = compression_format.value
 
 
 def infer_and_set_per_module_quantization_format(
     model: torch.nn.Module,
     sparsity_structure: Optional[str] = None,
+    quantization_format: Optional[str] = None,
 ) -> List[str]:
     """
     Infers the quantization format for a model based on its state and provided
     compression arguments. Updates thhe quantization_scheme.format value
-    based on the inferred format. Returns the unique list of formats in the model
-    or None if empty list
+    based on the inferred format. Returns the unique list of formats in the model.
+    All None formats are mapped to CompressionFormat.dense.value
 
     For a summary of the formats, see `docs/guides/compression_formats.md`.
 
     :param model: model to check for quantization
     :param sparsity_structure: optional sparsity applied to the module
-    :return compression format appropriate for model
+    :param quantization_format: optional global format to override
+        the per module formats
+    :return compression format appropriate for the model
     """
     unique_formats = []
     for submodule in model.modules():
         if is_module_quantized(submodule):
             assert hasattr(submodule, "quantization_scheme")
-            set_per_module_format(submodule, sparsity_structure)
-            if submodule.quantization_scheme.format not in unique_formats:
+            set_per_module_format(submodule, sparsity_structure, quantization_format)
+            if (
+                submodule.quantization_scheme.format
+                and submodule.quantization_scheme.format not in unique_formats
+            ):
                 unique_formats.append(submodule.quantization_scheme.format)
 
     if len(unique_formats) > 0:
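The helper can now also be driven with a single global format that overrides whatever each module carries; a mismatch against the inferred format only warns, it never raises. A rough usage sketch (hedged; the quantized `model` and the "float-quantized" format value are assumptions, not taken from this diff):

    import torch
    from compressed_tensors.config.format import (
        infer_and_set_per_module_quantization_format,
    )

    # Assumed: a torch.nn.Module whose quantized submodules already carry a
    # quantization_scheme (e.g. after a quantization config was applied).
    model: torch.nn.Module = ...

    formats = infer_and_set_per_module_quantization_format(
        model=model,
        sparsity_structure=None,                # no sparsity applied
        quantization_format="float-quantized",  # global override (assumed name)
    )
    # Every quantized submodule now has quantization_scheme.format set to the
    # override, and `formats` is the de-duplicated list of formats in the model.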

src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def to_dict(self):
 
     @staticmethod
     def from_pretrained(
-        model: Module, format: Optional[str] = None
+        model: Module, format: Optional[Union[str, list]] = None
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
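This only widens the accepted type: from_pretrained can now take the list of formats produced by infer_and_set_per_module_quantization_format instead of a single string. A hedged illustration (the quantized `model` and the format names are assumptions):

    import torch
    from compressed_tensors.quantization import QuantizationConfig

    model: torch.nn.Module = ...  # assumed: already-quantized model

    # Single format, as before:
    config = QuantizationConfig.from_pretrained(model, format="float-quantized")

    # Or a list of per-module formats, e.g. the return value of
    # infer_and_set_per_module_quantization_format:
    config = QuantizationConfig.from_pretrained(
        model, format=["float-quantized", "pack-quantized"]
    )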

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 4 additions & 4 deletions
@@ -342,7 +342,7 @@ def _get_combined_config(s_config, q_config):
 )
 def test_compress_model(model_stub, q_format, s_config, tmpdir):
     model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
 
     # compress model by eagerly compressing state dict
     true_compressed = dict(compressor.compress(model))
@@ -388,7 +388,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
     # Load model on CPU to get expected compressed state_dict
     cpu_model = AutoModelForCausalLM.from_pretrained(model_stub)
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, [q_format]
+        cpu_model, s_config, q_format
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -403,7 +403,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         module.to_empty(device="meta")
 
     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
     compressor.compress_model(meta_model)
 
     # Compare keys and dtypes
@@ -442,7 +442,7 @@ def test_multiple_quant_compressors():
 
     formats = [scheme_fp8.format, scheme_nvfp4.format]
 
-    compressor = ModelCompressor.from_pretrained_model(model, None, formats)
+    compressor = ModelCompressor.from_pretrained_model(model, None)
     assert isinstance(compressor.quantization_compressor, dict)
     assert (
         compressor.quantization_config.format == CompressionFormat.mixed_precision.value
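For downstream callers, the test updates above amount to the following migration (a sketch; `model`, `s_config`, and `q_format` stand for the same fixtures the tests use):

    # Before: a single quantization format had to be wrapped in a list.
    compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])

    # After: pass the format string directly, or omit it entirely and let the
    # per-module inference pick the format(s); mixed formats are detected
    # automatically, as test_multiple_quant_compressors now relies on.
    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
    compressor = ModelCompressor.from_pretrained_model(model, s_config)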
