@@ -50,7 +50,6 @@
     get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
-    merge_names,
     register_offload_parameter,
     update_parameter_data,
 )
@@ -321,112 +320,6 @@ def __init__(
                     format, config=quantization_config
                 )
 
-    # ----- used by hf quantizer ----- #
-
-    def get_missing_module_keys(self, model: Module) -> List[str]:
-        """
-        Identifies the expected missing weight keys in the compressed state_dict.
-
-        When a model undergoes sparsity or quantization compression, certain
-        weight tensors may be absent from the checkpoint by virtue of compression.
-        This function determines which weight keys are missing based on the
-        applied compression techniques.
-
-        :param model: The PyTorch model to check for missing keys.
-        :return: A list of missing keys expected in the compressed state_dict.
-        """
-        missing_keys = set()
-
-        # Determine missing keys due to sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-
-            missing_keys.update(
-                merge_names(target_name, "weight")
-                for target_name, _module in sparse_targets
-            )
-
-        # Determine missing keys due to pack quantization
-        if (
-            self.quantization_compressor
-            and self.quantization_config.format
-            == CompressionFormat.pack_quantized.value
-        ):
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                missing_keys.update(
-                    merge_names(target_name, "weight")
-                    for target_name, _module in quant_targets
-                )
-
-        return list(missing_keys)
-
-    def get_unexpected_file_keys(self, model: Module) -> List[str]:
-        """
-        Identifies extra keys introduced by the compression process in the
-        compressed state_dict that are not expected by the model graph.
-
-        During sparsity or quantization compression, additional metadata or
-        auxiliary parameters may be stored in the checkpoint, which do not
-        correspond to any parameter in the original model. These keys are
-        typically introduced to support the reconstruction of compressed weights.
-
-        For example, Sparse24Bitmask compression may introduce keys such as
-        'compressed', 'bitmask', and 'shape' in the checkpoint, which are
-        not part of the original model parameters.
-
-        :param model: The PyTorch model to check for unexpected keys.
-        :return: A list of extra keys introduced by the compression process
-            that are not expected by the model.
-        """
-
-        unexpected_keys = set()
-
-        # Identify unexpected keys from sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-            unexpected_keys.update(
-                merge_names(target_name, param)
-                for target_name, _module in sparse_targets
-                for param in self.sparsity_compressor.compression_param_names
-            )
-
-        # Identify unexpected keys from quantization compression
-        if self.quantization_compressor:
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                for quant_compressor in self.quantization_compressor.values():
-                    unexpected_keys.update(
-                        merge_names(target_name, param)
-                        for target_name, _module in quant_targets
-                        for param in quant_compressor.compression_param_names
-                        if param != "weight"
-                    )
-
-        return list(unexpected_keys)
-
     # ----- model memory compression/decompression pathways ----- #
 
     def compress_model(self, model: Module):
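
For reference, here is a minimal sketch of what the removed `get_missing_module_keys` computed, using the same two helpers the removed code relied on (`match_named_modules` and the now-unused `merge_names`, assuming both are importable from `compressed_tensors.utils` as in this file's imports). The toy model and target list are illustrative assumptions, not part of this change:

```python
import torch.nn as nn
from compressed_tensors.utils import match_named_modules, merge_names

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Every module matched by the compression targets is expected to lose its
# dense "weight" tensor in the compressed checkpoint. merge_names joins a
# module path and a parameter name with a dot: "0" + "weight" -> "0.weight".
missing_keys = {
    merge_names(name, "weight")
    for name, _module in match_named_modules(
        model=model, targets=["Linear"], ignore=[]
    )
}
print(sorted(missing_keys))  # ['0.weight', '2.weight']
```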
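
The removed `get_unexpected_file_keys` was the mirror image: each compressor advertises `compression_param_names`, and every non-`"weight"` entry becomes an extra checkpoint key for each matched module. A sketch under the same assumptions, with the parameter names taken from the Sparse24Bitmask example in the removed docstring:

```python
import torch.nn as nn
from compressed_tensors.utils import match_named_modules, merge_names

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Auxiliary tensors written per compressed module; 'compressed', 'bitmask',
# and 'shape' are the Sparse24Bitmask examples cited in the removed docstring.
compression_param_names = ("compressed", "bitmask", "shape")

unexpected_keys = {
    merge_names(name, param)
    for name, _module in match_named_modules(
        model=model, targets=["Linear"], ignore=[]
    )
    for param in compression_param_names
}
print(sorted(unexpected_keys))
# ['0.bitmask', '0.compressed', '0.shape', '2.bitmask', '2.compressed', '2.shape']
```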