
Commit d2daa9a

[ModelCompressor] Remove missing keys and missing modules (#462)

* remove missing keys and missing modules
* format, remove tests
* format

Parent: 9170fb3

File tree: 2 files changed, 0 additions and 162 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 0 additions & 107 deletions
@@ -50,7 +50,6 @@
     get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
-    merge_names,
     register_offload_parameter,
     update_parameter_data,
 )
@@ -321,112 +320,6 @@ def __init__(
             format, config=quantization_config
         )
 
-    # ----- used by hf quantizer ----- #
-
-    def get_missing_module_keys(self, model: Module) -> List[str]:
-        """
-        Identifies the expected missing weight keys in the compressed state_dict.
-
-        When a model undergoes sparsity or quantization compression, certain
-        weight tensors may be absent from the checkpoint by virtue of compression.
-        This function determines which weight keys are missing based on the
-        applied compression techniques.
-
-        :param model: The PyTorch model to check for missing keys.
-        :return: A list of missing keys expected in the compressed state_dict.
-        """
-        missing_keys = set()
-
-        # Determine missing keys due to sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-
-            missing_keys.update(
-                merge_names(target_name, "weight")
-                for target_name, _module in sparse_targets
-            )
-
-        # Determine missing keys due to pack quantization
-        if (
-            self.quantization_compressor
-            and self.quantization_config.format
-            == CompressionFormat.pack_quantized.value
-        ):
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                missing_keys.update(
-                    merge_names(target_name, "weight")
-                    for target_name, _module in quant_targets
-                )
-
-        return list(missing_keys)
-
-    def get_unexpected_file_keys(self, model: Module) -> List[str]:
-        """
-        Identifies extra keys introduced by the compression process in the
-        compressed state_dict that are not expected by the model graph.
-
-        During sparsity or quantization compression, additional metadata or
-        auxiliary parameters may be stored in the checkpoint, which do not
-        correspond to any parameter in the original model. These keys are
-        typically introduced to support the reconstruction of compressed weights.
-
-        For example, Sparse24Bitmask compression may introduce keys such as
-        'compressed', 'bitmask', and 'shape' in the checkpoint, which are
-        not part of the original model parameters.
-
-        :param model: The PyTorch model to check for unexpected keys.
-        :return: A list of extra keys introduced by the compression process
-            that are not expected by the model.
-        """
-
-        unexpected_keys = set()
-
-        # Identify unexpected keys from sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-            unexpected_keys.update(
-                merge_names(target_name, param)
-                for target_name, _module in sparse_targets
-                for param in self.sparsity_compressor.compression_param_names
-            )
-
-        # Identify unexpected keys from quantization compression
-        if self.quantization_compressor:
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                for quant_compressor in self.quantization_compressor.values():
-                    unexpected_keys.update(
-                        merge_names(target_name, param)
-                        for target_name, _module in quant_targets
-                        for param in quant_compressor.compression_param_names
-                        if param != "weight"
-                    )
-
-        return list(unexpected_keys)
-
     # ----- model memory compression/decompression pathways ----- #
 
     def compress_model(self, model: Module):
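The two removed helpers reduce to name arithmetic: match the target modules, then join each module name with a parameter suffix. Below is a minimal standalone sketch of that key construction, assuming merge_names joins names with a dot; the `matched` list is a hypothetical stand-in for what match_named_modules would return.

# Minimal sketch of the key construction the removed helpers performed.
# Assumption: compressed_tensors' merge_names joins names with a dot.
def merge_names(parent: str, child: str) -> str:
    return f"{parent}.{child}"

# Hypothetical module names standing in for match_named_modules() output.
matched = ["layer1", "layer2"]

# get_missing_module_keys: the dense "weight" tensor is absent from a
# compressed checkpoint, so its fully qualified key is reported as missing.
missing = {merge_names(name, "weight") for name in matched}

# get_unexpected_file_keys: the compressor's auxiliary parameters appear
# instead; "weight" itself is excluded, mirroring the removed code.
compression_param_names = ["weight_scale", "weight_zero_point", "weight"]
unexpected = {
    merge_names(name, param)
    for name in matched
    for param in compression_param_names
    if param != "weight"
}

print(sorted(missing))     # ['layer1.weight', 'layer2.weight']
print(sorted(unexpected))  # ['layer1.weight_scale', 'layer1.weight_zero_point',
                           #  'layer2.weight_scale', 'layer2.weight_zero_point']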

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 0 additions & 55 deletions
@@ -253,61 +253,6 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.parametrize(
-    "model, sparsity_config, quantization_config, expected",
-    [
-        (
-            TwoLayerModel(),
-            get_bitmask_sparsity_config(targets=["re:.*layer1$"]),
-            create_quantization_config(bits=8, type="int", strategy="channel"),
-            {"layer1.weight"},
-        )
-    ],
-)
-def test_get_missing_keys(model, sparsity_config, quantization_config, expected):
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
-
-    actual = model_compressor.get_missing_module_keys(model)
-    assert len(actual) == len(expected) and all(key in actual for key in expected)
-
-
-@pytest.mark.parametrize(
-    "model, sparsity_config, quantization_config, expected",
-    [
-        (
-            TwoLayerModel(),
-            get_bitmask_sparsity_config(targets=["re:.*layer1$"]),
-            create_quantization_config(bits=8, type="int", strategy="channel"),
-            {
-                f"{layer}.{suffix}"
-                for layer, suffixes in {
-                    "layer1": [
-                        "shape",
-                        "row_offsets",
-                        "weight_zero_point",
-                        "weight_g_idx",
-                        "bitmask",
-                        "weight_scale",
-                        "compressed",
-                    ],
-                    "layer2": ["weight_scale", "weight_zero_point", "weight_g_idx"],
-                }.items()
-                for suffix in suffixes
-            },
-        )
-    ],
-)
-def test_get_unexpected_keys(model, sparsity_config, quantization_config, expected):
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
-
-    actual = model_compressor.get_unexpected_file_keys(model)
-    assert len(actual) == len(expected) and all(key in actual for key in expected)
-
-
 def _create_dummy_checkpoint(state_dict, save_dir, model_compressor):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
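With the helpers gone, their parametrized tests go too. For context, the removed "# ----- used by hf quantizer ----- #" section existed so a checkpoint loader could separate mismatches the compression format predicts from genuine problems. A hypothetical sketch of that filtering pattern follows; the names reported_missing and expected_missing are illustrative, not an API from transformers or this repo.

from typing import Iterable, List, Set

def filter_expected(reported: Iterable[str], expected: Set[str]) -> List[str]:
    """Keep only the mismatches the compression format does not predict."""
    return [key for key in reported if key not in expected]

# Illustrative values: `expected_missing` would have come from
# get_missing_module_keys, `reported_missing` from state-dict loading.
expected_missing = {"layer1.weight"}
reported_missing = ["layer1.weight", "layer3.bias"]
print(filter_expected(reported_missing, expected_missing))  # ['layer3.bias']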
