@@ -344,6 +344,55 @@ def __init__(
344
344
format , config = quantization_config
345
345
)
346
346
347
+ def get_missing_module_keys (self , model : Module ) -> List [str ]:
348
+ """
349
+ Identifies the expected missing weight keys in the compressed state_dict.
350
+
351
+ When a model undergoes sparsity or quantization compression, certain
352
+ weight tensors may be absent from the checkpoint by virtue of compression.
353
+ This function determines which weight keys are missing based on the
354
+ applied compression techniques.
355
+
356
+ :param model: The PyTorch model to check for missing keys.
357
+ :return: A list of missing keys expected in the compressed state_dict.
358
+ """
359
+ missing_keys = set ()
360
+
361
+ # Determine missing keys due to sparsity compression
362
+ if (
363
+ self .sparsity_compressor
364
+ and self .sparsity_config .format != CompressionFormat .dense .value
365
+ ):
366
+ sparse_targets = match_named_modules (
367
+ model = model ,
368
+ targets = self .sparsity_config .targets ,
369
+ ignore = self .sparsity_config .ignore ,
370
+ )
371
+
372
+ missing_keys .update (
373
+ merge_names (target_name , "weight" )
374
+ for target_name , _module in sparse_targets
375
+ )
376
+
377
+ # Determine missing keys due to pack quantization
378
+ if (
379
+ self .quantization_compressor
380
+ and self .quantization_config .format
381
+ == CompressionFormat .pack_quantized .value
382
+ ):
383
+ for scheme in self .quantization_config .config_groups .values ():
384
+ quant_targets = match_named_modules (
385
+ model = model ,
386
+ targets = scheme .targets ,
387
+ ignore = self .quantization_config .ignore ,
388
+ )
389
+ missing_keys .update (
390
+ merge_names (target_name , "weight" )
391
+ for target_name , _module in quant_targets
392
+ )
393
+
394
+ return list (missing_keys )
395
+
347
396
def get_unexpected_file_keys (self , model : Module ) -> List [str ]:
348
397
"""
349
398
Identifies extra keys introduced by the compression process in the
0 commit comments