Skip to content

Commit e85d14a

Browse files
authored
add back get missing keys to support transformers lower bound (#475)
1 parent 3e4d7fa commit e85d14a

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
get_offloaded_device,
5353
get_safetensors_folder,
5454
has_offloaded_params,
55+
merge_names,
5556
patch_attr,
5657
register_offload_parameter,
5758
update_parameter_data,
@@ -343,6 +344,61 @@ def __init__(
343344
format, config=quantization_config
344345
)
345346

347+
def get_unexpected_file_keys(self, model: Module) -> List[str]:
    """
    Collect checkpoint keys produced by compression that the model graph
    itself does not define.

    Sparsity and quantization compression may serialize auxiliary
    parameters (for example, Sparse24Bitmask adds 'compressed', 'bitmask'
    and 'shape' entries) that have no counterpart among the original
    model parameters; this method enumerates all such keys.

    :param model: The PyTorch model whose modules are matched against
        the compression configs.
    :return: A list of compression-introduced keys that are not expected
        by the model.
    """

    extra_keys = set()

    # Sparsity compression contributes auxiliary params per matched
    # target, except in the dense format (dense == no transformation,
    # hence no extra keys).
    if (
        self.sparsity_compressor
        and self.sparsity_config.format != CompressionFormat.dense.value
    ):
        matches = match_named_modules(
            model=model,
            targets=self.sparsity_config.targets,
            ignore=self.sparsity_config.ignore,
        )
        for module_name, _ in matches:
            for param_name in self.sparsity_compressor.compression_param_names:
                extra_keys.add(merge_names(module_name, param_name))

    # Quantization compression: every matched target gains the
    # compressor's auxiliary params. 'weight' itself is part of the
    # model graph and therefore excluded.
    if self.quantization_compressor:
        for scheme in self.quantization_config.config_groups.values():
            matches = match_named_modules(
                model=model,
                targets=scheme.targets,
                ignore=self.quantization_config.ignore,
            )
            for compressor in self.quantization_compressor.values():
                for module_name, _ in matches:
                    for param_name in compressor.compression_param_names:
                        if param_name != "weight":
                            extra_keys.add(merge_names(module_name, param_name))

    return list(extra_keys)
401+
346402
# ----- model memory compression/decompression pathways ----- #
347403

348404
def compress_model(self, model: Module):

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,41 @@ def test_multiple_quant_compressors():
450450
assert all(format in compressor.quantization_compressor for format in formats)
451451

452452

453+
@pytest.mark.parametrize(
    "model, sparsity_config, quantization_config, expected",
    [
        (
            TwoLayerModel(),
            get_bitmask_sparsity_config(targets=["re:.*layer1$"]),
            create_quantization_config(bits=8, type="int", strategy="channel"),
            {
                # layer1 matches the sparsity targets, so it carries both
                # bitmask-sparsity keys and channel-quantization keys;
                # layer2 is quantized only.
                "layer1.shape",
                "layer1.row_offsets",
                "layer1.weight_zero_point",
                "layer1.weight_g_idx",
                "layer1.bitmask",
                "layer1.weight_scale",
                "layer1.compressed",
                "layer2.weight_scale",
                "layer2.weight_zero_point",
                "layer2.weight_g_idx",
            },
        )
    ],
)
def test_get_unexpected_keys(model, sparsity_config, quantization_config, expected):
    """Compression-introduced checkpoint keys are reported exactly once each."""
    compressor = ModelCompressor(
        sparsity_config=sparsity_config, quantization_config=quantization_config
    )

    reported = compressor.get_unexpected_file_keys(model)
    # Equivalent to the original dual check (equal length + membership):
    # with a set of expected keys, a length-matching superset must be an
    # exact, duplicate-free match.
    assert sorted(reported) == sorted(expected)
486+
487+
453488
@pytest.mark.parametrize(
454489
"model_stub,comp_stub",
455490
[

0 commit comments

Comments (0)