42
42
apply_quantization_config ,
43
43
load_pretrained_quantization_parameters ,
44
44
)
45
+ from compressed_tensors .quantization .lifecycle import expand_target_names
45
46
from compressed_tensors .quantization .utils import is_module_quantized
46
47
from compressed_tensors .transform import TransformConfig
47
48
from compressed_tensors .utils import (
59
60
fix_fsdp_module_name ,
60
61
is_compressed_tensors_config ,
61
62
)
62
- from compressed_tensors .utils .match import match_named_modules
63
63
from torch import Tensor
64
64
from torch .nn import Module
65
65
from tqdm import tqdm
@@ -342,15 +342,13 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
342
342
self .sparsity_compressor
343
343
and self .sparsity_config .format != CompressionFormat .dense .value
344
344
):
345
- sparse_targets = match_named_modules (
345
+ sparse_targets = expand_target_names (
346
346
model = model ,
347
347
targets = self .sparsity_config .targets ,
348
348
ignore = self .sparsity_config .ignore ,
349
349
)
350
-
351
350
missing_keys .update (
352
- merge_names (target_name , "weight" )
353
- for target_name , _module in sparse_targets
351
+ merge_names (target , "weight" ) for target in sparse_targets
354
352
)
355
353
356
354
# Determine missing keys due to pack quantization
@@ -360,14 +358,13 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
360
358
== CompressionFormat .pack_quantized .value
361
359
):
362
360
for scheme in self .quantization_config .config_groups .values ():
363
- quant_targets = match_named_modules (
361
+ quant_targets = expand_target_names (
364
362
model = model ,
365
363
targets = scheme .targets ,
366
364
ignore = self .quantization_config .ignore ,
367
365
)
368
366
missing_keys .update (
369
- merge_names (target_name , "weight" )
370
- for target_name , _module in quant_targets
367
+ merge_names (target , "weight" ) for target in quant_targets
371
368
)
372
369
373
370
return list (missing_keys )
@@ -398,29 +395,29 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
398
395
self .sparsity_compressor
399
396
and self .sparsity_config .format != CompressionFormat .dense .value
400
397
):
401
- sparse_targets = match_named_modules (
398
+ sparse_targets : Set [ str ] = expand_target_names (
402
399
model = model ,
403
400
targets = self .sparsity_config .targets ,
404
401
ignore = self .sparsity_config .ignore ,
405
402
)
406
403
unexpected_keys .update (
407
- merge_names (target_name , param )
408
- for target_name , _module in sparse_targets
404
+ merge_names (target , param )
405
+ for target in sparse_targets
409
406
for param in self .sparsity_compressor .compression_param_names
410
407
)
411
408
412
409
# Identify unexpected keys from quantization compression
413
410
if self .quantization_compressor :
414
411
for scheme in self .quantization_config .config_groups .values ():
415
- quant_targets = match_named_modules (
412
+ quant_targets : Set [ str ] = expand_target_names (
416
413
model = model ,
417
414
targets = scheme .targets ,
418
415
ignore = self .quantization_config .ignore ,
419
416
)
420
417
for quant_compressor in self .quantization_compressor .values ():
421
418
unexpected_keys .update (
422
- merge_names (target_name , param )
423
- for target_name , _module in quant_targets
419
+ merge_names (target , param )
420
+ for target in quant_targets
424
421
for param in quant_compressor .compression_param_names
425
422
if param != "weight"
426
423
)
@@ -437,79 +434,73 @@ def compress_model(self, model: Module):
437
434
:param model: model containing parameters to compress
438
435
"""
439
436
module_to_scheme = map_module_to_scheme (model )
440
- sparse_compression_targets = [
441
- module_name
442
- for module_name , _module in match_named_modules (
443
- model = model ,
444
- targets = self .sparsity_config .targets if self .sparsity_config else [],
445
- ignore = self .sparsity_config .ignore if self .sparsity_config else [],
446
- )
447
- ]
448
- for prefix , module in tqdm (
449
- match_named_modules (
450
- model ,
451
- [* sparse_compression_targets , * module_to_scheme .keys ()],
452
- warn_on_fail = True ,
453
- ),
454
- desc = "Compressing model" ,
455
- ):
456
- module_device = get_execution_device (module )
457
- is_meta = module_device .type == "meta"
458
-
459
- exec_device = "meta" if is_meta else "cpu"
460
- onloading_device = "meta" if is_meta else module_device
461
-
462
- # in the future, support compression on same device
463
- with align_module_device (module , execution_device = exec_device ):
464
- state_dict = {
465
- f"{ prefix } .{ name } " : param
466
- for name , param in module .named_parameters (recurse = False )
467
- }
468
-
469
- # quantization first
470
- if prefix in module_to_scheme :
471
- if (
472
- not hasattr (module .quantization_scheme , "format" )
473
- or module .quantization_scheme .format is None
474
- ):
475
- if len (self .compression_formats ) > 1 :
476
- raise ValueError (
477
- "Applying multiple compressors without defining "
478
- "per module formats is not supported "
479
- )
480
- format = self .compression_formats [0 ]
481
- else :
482
- format = module .quantization_scheme .format
483
-
484
- quant_compressor = self .quantization_compressor .get (format )
485
- state_dict = quant_compressor .compress (
486
- state_dict ,
487
- names_to_scheme = module_to_scheme ,
488
- show_progress = False ,
489
- compression_device = exec_device ,
490
- )
437
+ sparse_compression_targets : Set [str ] = expand_target_names (
438
+ model = model ,
439
+ targets = self .sparsity_config .targets if self .sparsity_config else [],
440
+ ignore = self .sparsity_config .ignore if self .sparsity_config else [],
441
+ )
491
442
492
- # sparsity second
493
- if prefix in sparse_compression_targets :
494
- state_dict = self .sparsity_compressor .compress (
495
- state_dict ,
496
- compression_targets = sparse_compression_targets ,
497
- show_progress = False ,
498
- )
443
+ for prefix , module in tqdm (model .named_modules (), desc = "Compressing model" ):
444
+
445
+ if prefix in module_to_scheme or prefix in sparse_compression_targets :
446
+ module_device = get_execution_device (module )
447
+ is_meta = module_device .type == "meta"
448
+
449
+ exec_device = "meta" if is_meta else "cpu"
450
+ onloading_device = "meta" if is_meta else module_device
451
+
452
+ # in the future, support compression on same device
453
+ with align_module_device (module , execution_device = exec_device ):
454
+ state_dict = {
455
+ f"{ prefix } .{ name } " : param
456
+ for name , param in module .named_parameters (recurse = False )
457
+ }
458
+
459
+ # quantization first
460
+ if prefix in module_to_scheme :
461
+ if (
462
+ not hasattr (module .quantization_scheme , "format" )
463
+ or module .quantization_scheme .format is None
464
+ ):
465
+ if len (self .compression_formats ) > 1 :
466
+ raise ValueError (
467
+ "Applying multiple compressors without defining "
468
+ "per module formats is not supported "
469
+ )
470
+ format = self .compression_formats [0 ]
471
+ else :
472
+ format = module .quantization_scheme .format
473
+
474
+ quant_compressor = self .quantization_compressor .get (format )
475
+ state_dict = quant_compressor .compress (
476
+ state_dict ,
477
+ names_to_scheme = module_to_scheme ,
478
+ show_progress = False ,
479
+ compression_device = exec_device ,
480
+ )
481
+
482
+ # sparsity second
483
+ if prefix in sparse_compression_targets :
484
+ state_dict = self .sparsity_compressor .compress (
485
+ state_dict ,
486
+ compression_targets = sparse_compression_targets ,
487
+ show_progress = False ,
488
+ )
499
489
500
- # remove any existing parameters
501
- offload_device = get_offloaded_device (module )
502
- for name , _ in list (module .named_parameters (recurse = False )):
503
- delete_offload_parameter (module , name )
490
+ # remove any existing parameters
491
+ offload_device = get_offloaded_device (module )
492
+ for name , _ in list (module .named_parameters (recurse = False )):
493
+ delete_offload_parameter (module , name )
504
494
505
- # replace with compressed parameters
506
- for name , value in state_dict .items ():
507
- name = name .removeprefix (f"{ prefix } ." )
508
- value = value .to (onloading_device )
509
- param = torch .nn .Parameter (value , requires_grad = False )
510
- register_offload_parameter (module , name , param , offload_device )
495
+ # replace with compressed parameters
496
+ for name , value in state_dict .items ():
497
+ name = name .removeprefix (f"{ prefix } ." )
498
+ value = value .to (onloading_device )
499
+ param = torch .nn .Parameter (value , requires_grad = False )
500
+ register_offload_parameter (module , name , param , offload_device )
501
+
502
+ module .quantization_status = QuantizationStatus .COMPRESSED
511
503
512
- module .quantization_status = QuantizationStatus .COMPRESSED
513
504
# TODO: consider sparse compression to also be compression
514
505
if (
515
506
self .quantization_config is not None
@@ -525,75 +516,67 @@ def decompress_model(self, model: Module):
525
516
:param model: model containing parameters to compress
526
517
"""
527
518
module_to_scheme = map_module_to_scheme (model )
528
- sparse_compression_targets = [
529
- module_name
530
- for module_name , _module in match_named_modules (
531
- model = model ,
532
- targets = self .sparsity_config .targets if self .sparsity_config else [],
533
- ignore = self .sparsity_config .ignore if self .sparsity_config else [],
534
- )
535
- ]
536
-
537
- for prefix , module in tqdm (
538
- match_named_modules (
539
- model ,
540
- [* sparse_compression_targets , * module_to_scheme .keys ()],
541
- warn_on_fail = True ,
542
- ),
543
- desc = "Decompressing model" ,
544
- ):
545
- # in the future, support decompression on same device
546
- with align_module_device (module , execution_device = "cpu" ):
547
- state_dict = {
548
- f"{ prefix } .{ name } " : param
549
- for name , param in module .named_parameters (recurse = False )
550
- }
551
-
552
- # sparsity first
553
- if prefix in sparse_compression_targets :
554
- # sparse_compression_targets are automatically inferred by this fn
555
- generator = self .sparsity_compressor .decompress_from_state_dict (
556
- state_dict ,
557
- )
558
- # generates (param_path, param_val)
559
- # of compressed and unused params
560
- state_dict = {key : value for key , value in generator }
561
-
562
- # quantization second
563
- if prefix in module_to_scheme :
564
- if (
565
- not hasattr (module .quantization_scheme , "format" )
566
- or module .quantization_scheme .format is None
567
- ):
568
- if len (self .compression_formats ) > 1 :
569
- raise ValueError (
570
- "Applying multiple compressors without defining "
571
- "per module formats is not supported "
572
- )
573
- format = self .compression_formats [0 ]
574
- else :
575
- format = module .quantization_scheme .format
576
- quant_compressor = self .quantization_compressor .get (format )
577
- state_dict = quant_compressor .decompress_module_from_state_dict (
578
- prefix ,
579
- state_dict ,
580
- scheme = module_to_scheme [prefix ],
581
- )
519
+ sparse_compression_targets : Set [str ] = expand_target_names (
520
+ model = model ,
521
+ targets = self .sparsity_config .targets if self .sparsity_config else [],
522
+ ignore = self .sparsity_config .ignore if self .sparsity_config else [],
523
+ )
524
+
525
+ for prefix , module in tqdm (model .named_modules (), desc = "Decompressing model" ):
526
+ if prefix in module_to_scheme or prefix in sparse_compression_targets :
527
+ # in the future, support decompression on same device
528
+ with align_module_device (module , execution_device = "cpu" ):
529
+ state_dict = {
530
+ f"{ prefix } .{ name } " : param
531
+ for name , param in module .named_parameters (recurse = False )
532
+ }
533
+
534
+ # sparsity first
535
+ if prefix in sparse_compression_targets :
536
+ # sparse_compression_targets are automatically inferred by this fn
537
+ generator = self .sparsity_compressor .decompress_from_state_dict (
538
+ state_dict ,
539
+ )
540
+ # generates (param_path, param_val)
541
+ # of compressed and unused params
542
+ state_dict = {key : value for key , value in generator }
543
+
544
+ # quantization second
545
+ if prefix in module_to_scheme :
546
+
547
+ if (
548
+ not hasattr (module .quantization_scheme , "format" )
549
+ or module .quantization_scheme .format is None
550
+ ):
551
+ if len (self .compression_formats ) > 1 :
552
+ raise ValueError (
553
+ "Applying multiple compressors without defining "
554
+ "per module formats is not supported "
555
+ )
556
+ format = self .compression_formats [0 ]
557
+ else :
558
+ format = module .quantization_scheme .format
559
+ quant_compressor = self .quantization_compressor .get (format )
560
+ state_dict = quant_compressor .decompress_module_from_state_dict (
561
+ prefix ,
562
+ state_dict ,
563
+ scheme = module_to_scheme [prefix ],
564
+ )
582
565
583
- # remove any existing parameters
584
- exec_device = get_execution_device (module )
585
- offload_device = get_offloaded_device (module )
586
- for name , _ in list (module .named_parameters (recurse = False )):
587
- delete_offload_parameter (module , name )
566
+ # remove any existing parameters
567
+ exec_device = get_execution_device (module )
568
+ offload_device = get_offloaded_device (module )
569
+ for name , _ in list (module .named_parameters (recurse = False )):
570
+ delete_offload_parameter (module , name )
588
571
589
- # replace with decompressed parameters
590
- for name , value in state_dict .items ():
591
- name = name .removeprefix (f"{ prefix } ." )
592
- value = value .to (exec_device )
593
- param = torch .nn .Parameter (value , requires_grad = False )
594
- register_offload_parameter (module , name , param , offload_device )
572
+ # replace with decompressed parameters
573
+ for name , value in state_dict .items ():
574
+ name = name .removeprefix (f"{ prefix } ." )
575
+ value = value .to (exec_device )
576
+ param = torch .nn .Parameter (value , requires_grad = False )
577
+ register_offload_parameter (module , name , param , offload_device )
595
578
596
- module .quantization_status = QuantizationStatus .FROZEN
579
+ module .quantization_status = QuantizationStatus .FROZEN
597
580
598
581
# ----- state dict compression pathways ----- #
599
582
@@ -631,14 +614,11 @@ def compress(
631
614
)
632
615
633
616
if self .sparsity_compressor is not None :
634
- sparse_compression_targets : Set [str ] = {
635
- module_name
636
- for module_name , _module in match_named_modules (
637
- model = model ,
638
- targets = self .sparsity_config .targets ,
639
- ignore = self .sparsity_config .ignore ,
640
- )
641
- }
617
+ sparse_compression_targets : Set [str ] = expand_target_names (
618
+ model = model ,
619
+ targets = self .sparsity_config .targets ,
620
+ ignore = self .sparsity_config .ignore ,
621
+ )
642
622
state_dict = self .sparsity_compressor .compress (
643
623
state_dict ,
644
624
compression_targets = sparse_compression_targets ,
@@ -703,6 +683,7 @@ def decompress(self, model_path: str, model: Module):
703
683
with override_quantization_status (
704
684
self .quantization_config , QuantizationStatus .FROZEN
705
685
):
686
+
706
687
names_to_scheme = apply_quantization_config (
707
688
model , self .quantization_config
708
689
)
0 commit comments