Skip to content

Commit 6800382

Browse files
authored
add get_missing_module_keys to support transformers lower bound (#479)
1 parent 0d8c7c3 commit 6800382

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,55 @@ def __init__(
344344
format, config=quantization_config
345345
)
346346

347+
def get_missing_module_keys(self, model: Module) -> List[str]:
    """
    Identify the weight keys expected to be absent from the compressed
    state_dict.

    When a model undergoes sparsity or quantization compression, certain
    weight tensors may be absent from the checkpoint by virtue of
    compression. This method determines which weight keys are missing
    based on the applied compression techniques.

    :param model: The PyTorch model to check for missing keys.
    :return: A list of missing keys expected in the compressed state_dict.
    """
    missing_keys = set()

    # Non-dense sparsity compression removes the raw weight tensor of
    # every module targeted by the sparsity config
    if (
        self.sparsity_compressor
        and self.sparsity_config.format != CompressionFormat.dense.value
    ):
        for name, _ in match_named_modules(
            model=model,
            targets=self.sparsity_config.targets,
            ignore=self.sparsity_config.ignore,
        ):
            missing_keys.add(merge_names(name, "weight"))

    # Pack quantization replaces the original weight with a packed
    # representation, so the plain "weight" key disappears as well
    if (
        self.quantization_compressor
        and self.quantization_config.format
        == CompressionFormat.pack_quantized.value
    ):
        for scheme in self.quantization_config.config_groups.values():
            for name, _ in match_named_modules(
                model=model,
                targets=scheme.targets,
                ignore=self.quantization_config.ignore,
            ):
                missing_keys.add(merge_names(name, "weight"))

    return list(missing_keys)
395+
347396
def get_unexpected_file_keys(self, model: Module) -> List[str]:
348397
"""
349398
Identifies extra keys introduced by the compression process in the

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,26 @@ def test_multiple_quant_compressors():
450450
assert all(format in compressor.quantization_compressor for format in formats)
451451

452452

453+
@pytest.mark.parametrize(
    "model, sparsity_config, quantization_config, expected",
    [
        (
            TwoLayerModel(),
            get_bitmask_sparsity_config(targets=["re:.*layer1$"]),
            create_quantization_config(bits=8, type="int", strategy="channel"),
            {"layer1.weight"},
        )
    ],
)
def test_get_missing_keys(model, sparsity_config, quantization_config, expected):
    """Check that the compressor reports exactly the expected missing keys."""
    compressor = ModelCompressor(
        sparsity_config=sparsity_config, quantization_config=quantization_config
    )

    missing = compressor.get_missing_module_keys(model)
    # Same size and full containment together imply the two agree as sets
    assert len(missing) == len(expected)
    assert all(key in missing for key in expected)
471+
472+
453473
@pytest.mark.parametrize(
454474
"model, sparsity_config, quantization_config, expected",
455475
[

0 commit comments

Comments
 (0)