52 | 52 |     get_offloaded_device,
53 | 53 |     get_safetensors_folder,
54 | 54 |     has_offloaded_params,
| 55 | +     merge_names,
55 | 56 |     patch_attr,
56 | 57 |     register_offload_parameter,
57 | 58 |     update_parameter_data,

@@ -343,6 +344,61 @@ def __init__(

343 | 344 |                         format, config=quantization_config
344 | 345 |                     )
345 | 346 |
| 347 | +     def get_unexpected_file_keys(self, model: Module) -> List[str]:
| 348 | +         """
| 349 | +         Identifies extra keys introduced by the compression process in the
| 350 | +         compressed state_dict that are not expected by the model graph.
| 351 | +
| 352 | +         During sparsity or quantization compression, additional metadata or
| 353 | +         auxiliary parameters may be stored in the checkpoint, which do not
| 354 | +         correspond to any parameter in the original model. These keys are
| 355 | +         typically introduced to support the reconstruction of compressed weights.
| 356 | +
| 357 | +         For example, Sparse24Bitmask compression may introduce keys such as
| 358 | +         'compressed', 'bitmask', and 'shape' in the checkpoint, which are
| 359 | +         not part of the original model parameters.
| 360 | +
| 361 | +         :param model: The PyTorch model to check for unexpected keys.
| 362 | +         :return: A list of extra keys introduced by the compression process
| 363 | +             that are not expected by the model.
| 364 | +         """
| 365 | +
| 366 | +         unexpected_keys = set()
| 367 | +
| 368 | +         # Identify unexpected keys from sparsity compression
| 369 | +         if (
| 370 | +             self.sparsity_compressor
| 371 | +             and self.sparsity_config.format != CompressionFormat.dense.value
| 372 | +         ):
| 373 | +             sparse_targets = match_named_modules(
| 374 | +                 model=model,
| 375 | +                 targets=self.sparsity_config.targets,
| 376 | +                 ignore=self.sparsity_config.ignore,
| 377 | +             )
| 378 | +             unexpected_keys.update(
| 379 | +                 merge_names(target_name, param)
| 380 | +                 for target_name, _module in sparse_targets
| 381 | +                 for param in self.sparsity_compressor.compression_param_names
| 382 | +             )
| 383 | +
| 384 | +         # Identify unexpected keys from quantization compression
| 385 | +         if self.quantization_compressor:
| 386 | +             for scheme in self.quantization_config.config_groups.values():
| 387 | +                 # materialize so the matches can be reused for every compressor
| 388 | +                 quant_targets = list(
| 389 | +                     match_named_modules(
| 390 | +                         model=model,
| 391 | +                         targets=scheme.targets,
| 392 | +                         ignore=self.quantization_config.ignore,
| 393 | +                     )
| 394 | +                 )
| 395 | +                 for quant_compressor in self.quantization_compressor.values():
| 396 | +                     unexpected_keys.update(
| 397 | +                         merge_names(target_name, param)
| 398 | +                         for target_name, _module in quant_targets
| 399 | +                         for param in quant_compressor.compression_param_names
| 400 | +                         if param != "weight"
| 401 | +                     )
| 402 | +
| 403 | +         return list(unexpected_keys)
| 404 | +
346 | 405 |     # ----- model memory compression/decompression pathways ----- #
347 | 406 |
348 | 407 |     def compress_model(self, model: Module):
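The diff adds the helper but does not show it in use. Below is a minimal usage sketch, not taken from the PR: it assumes a compressed checkpoint directory containing a single `model.safetensors` file, and it uses `ModelCompressor.from_pretrained`, `AutoModelForCausalLM.from_pretrained`, and safetensors' `load_file`; the paths and the `unexplained` list are illustrative placeholders.

```python
# Hypothetical sketch: separate compression-only checkpoint keys from keys
# that are genuinely unexpected. Paths below are placeholders.
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

from compressed_tensors.compressors import ModelCompressor

model = AutoModelForCausalLM.from_pretrained("path/to/uncompressed-model")
compressor = ModelCompressor.from_pretrained("path/to/compressed-checkpoint")

# Keys such as "<module>.bitmask" or "<module>.shape" are written by the
# compressors but have no counterpart in model.state_dict().
compression_keys = set(compressor.get_unexpected_file_keys(model))

checkpoint = load_file("path/to/compressed-checkpoint/model.safetensors")
model_keys = set(model.state_dict().keys())

unexplained = [
    key
    for key in checkpoint
    if key not in model_keys and key not in compression_keys
]
print(f"{len(compression_keys)} keys are expected compression metadata")
print(f"{len(unexplained)} keys are truly unexpected: {unexplained[:5]}")
```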