@@ -42,7 +42,6 @@
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
@@ -60,6 +59,7 @@
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.match import match_named_modules
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -342,13 +342,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
+
             missing_keys.update(
-                merge_names(target, "weight") for target in sparse_targets
+                merge_names(target_name, "weight")
+                for target_name, _module in sparse_targets
             )
 
         # Determine missing keys due to pack quantization
@@ -358,13 +360,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             == CompressionFormat.pack_quantized.value
         ):
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 missing_keys.update(
-                    merge_names(target, "weight") for target in quant_targets
+                    merge_names(target_name, "weight")
+                    for target_name, _module in quant_targets
                 )
 
         return list(missing_keys)
@@ -395,29 +398,29 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets: Set[str] = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
             unexpected_keys.update(
-                merge_names(target, param)
-                for target in sparse_targets
+                merge_names(target_name, param)
+                for target_name, _module in sparse_targets
                 for param in self.sparsity_compressor.compression_param_names
             )
 
         # Identify unexpected keys from quantization compression
         if self.quantization_compressor:
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets: Set[str] = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 for quant_compressor in self.quantization_compressor.values():
                     unexpected_keys.update(
-                        merge_names(target, param)
-                        for target in quant_targets
+                        merge_names(target_name, param)
+                        for target_name, _module in quant_targets
                         for param in quant_compressor.compression_param_names
                         if param != "weight"
                     )
@@ -434,73 +437,79 @@ def compress_model(self, model: Module):
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
-
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module)
-                is_meta = module_device.type == "meta"
-
-                exec_device = "meta" if is_meta else "cpu"
-                onloading_device = "meta" if is_meta else module_device
-
-                # in the future, support compression on same device
-                with align_module_device(module, execution_device=exec_device):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                # quantization first
-                if prefix in module_to_scheme:
-                    if (
-                        not hasattr(module.quantization_scheme, "format")
-                        or module.quantization_scheme.format is None
-                    ):
-                        if len(self.compression_formats) > 1:
-                            raise ValueError(
-                                "Applying multiple compressors without defining "
-                                "per module formats is not supported "
-                            )
-                        format = self.compression_formats[0]
-                    else:
-                        format = module.quantization_scheme.format
-
-                    quant_compressor = self.quantization_compressor.get(format)
-                    state_dict = quant_compressor.compress(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
-                        show_progress=False,
-                        compression_device=exec_device,
-                    )
-
-                # sparsity second
-                if prefix in sparse_compression_targets:
-                    state_dict = self.sparsity_compressor.compress(
-                        state_dict,
-                        compression_targets=sparse_compression_targets,
-                        show_progress=False,
-                    )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Compressing model",
+        ):
+            module_device = get_execution_device(module)
+            is_meta = module_device.type == "meta"
+
+            exec_device = "meta" if is_meta else "cpu"
+            onloading_device = "meta" if is_meta else module_device
+
+            # in the future, support compression on same device
+            with align_module_device(module, execution_device=exec_device):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # quantization first
+            if prefix in module_to_scheme:
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if len(self.compression_formats) > 1:
+                        raise ValueError(
+                            "Applying multiple compressors without defining "
+                            "per module formats is not supported "
+                        )
+                    format = self.compression_formats[0]
+                else:
+                    format = module.quantization_scheme.format
+
+                quant_compressor = self.quantization_compressor.get(format)
+                state_dict = quant_compressor.compress(
+                    state_dict,
+                    names_to_scheme=module_to_scheme,
+                    show_progress=False,
+                    compression_device=exec_device,
+                )
 
-                # remove any existing parameters
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # sparsity second
+            if prefix in sparse_compression_targets:
+                state_dict = self.sparsity_compressor.compress(
+                    state_dict,
+                    compression_targets=sparse_compression_targets,
+                    show_progress=False,
+                )
 
-                # replace with compressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(onloading_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # remove any existing parameters
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-                module.quantization_status = QuantizationStatus.COMPRESSED
+            # replace with compressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(onloading_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
+            module.quantization_status = QuantizationStatus.COMPRESSED
         # TODO: consider sparse compression to also be compression
         if (
             self.quantization_config is not None
@@ -516,67 +525,75 @@ def decompress_model(self, model: Module):
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                # in the future, support decompression on same device
-                with align_module_device(module, execution_device="cpu"):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                # sparsity first
-                if prefix in sparse_compression_targets:
-                    # sparse_compression_targets are automatically inferred by this fn
-                    generator = self.sparsity_compressor.decompress_from_state_dict(
-                        state_dict,
-                    )
-                    # generates (param_path, param_val)
-                    # of compressed and unused params
-                    state_dict = {key: value for key, value in generator}
-
-                # quantization second
-                if prefix in module_to_scheme:
-
-                    if (
-                        not hasattr(module.quantization_scheme, "format")
-                        or module.quantization_scheme.format is None
-                    ):
-                        if len(self.compression_formats) > 1:
-                            raise ValueError(
-                                "Applying multiple compressors without defining "
-                                "per module formats is not supported "
-                            )
-                        format = self.compression_formats[0]
-                    else:
-                        format = module.quantization_scheme.format
-                    quant_compressor = self.quantization_compressor.get(format)
-                    state_dict = quant_compressor.decompress_module_from_state_dict(
-                        prefix,
-                        state_dict,
-                        scheme=module_to_scheme[prefix],
-                    )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Decompressing model",
+        ):
+            # in the future, support decompression on same device
+            with align_module_device(module, execution_device="cpu"):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # sparsity first
+            if prefix in sparse_compression_targets:
+                # sparse_compression_targets are automatically inferred by this fn
+                generator = self.sparsity_compressor.decompress_from_state_dict(
+                    state_dict,
+                )
+                # generates (param_path, param_val)
+                # of compressed and unused params
+                state_dict = {key: value for key, value in generator}
+
+            # quantization second
+            if prefix in module_to_scheme:
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if len(self.compression_formats) > 1:
+                        raise ValueError(
+                            "Applying multiple compressors without defining "
+                            "per module formats is not supported "
+                        )
+                    format = self.compression_formats[0]
+                else:
+                    format = module.quantization_scheme.format
+                quant_compressor = self.quantization_compressor.get(format)
+                state_dict = quant_compressor.decompress_module_from_state_dict(
+                    prefix,
+                    state_dict,
+                    scheme=module_to_scheme[prefix],
+                )
 
-                # remove any existing parameters
-                exec_device = get_execution_device(module)
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-                # replace with decompressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with decompressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(exec_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
-                module.quantization_status = QuantizationStatus.FROZEN
+            module.quantization_status = QuantizationStatus.FROZEN
 
     # ----- state dict compression pathways ----- #
 
@@ -614,11 +631,14 @@ def compress(
            )

        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_target_names(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
+            sparse_compression_targets: Set[str] = {
+                module_name
+                for module_name, _module in match_named_modules(
+                    model=model,
+                    targets=self.sparsity_config.targets,
+                    ignore=self.sparsity_config.ignore,
+                )
+            }
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
@@ -683,7 +703,6 @@ def decompress(self, model_path: str, model: Module):
            with override_quantization_status(
                self.quantization_config, QuantizationStatus.FROZEN
            ):
-
                names_to_scheme = apply_quantization_config(
                    model, self.quantization_config
                )
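
Note on the API this diff swaps in: `expand_target_names` returned a flat collection of module names, while `match_named_modules` yields `(name, module)` pairs, which is why every call site above now either unpacks the tuple or projects the names back out with a comprehension. Below is a minimal sketch of the new pattern, assuming only the `model`/`targets`/`ignore`/`warn_on_fail` arguments visible in this diff and using a toy `nn.Sequential` in place of a real checkpoint:

```python
import torch.nn as nn

from compressed_tensors.utils.match import match_named_modules

# toy stand-in for a real model; "Linear" targets submodules by class name
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# old style: expand_target_names(...) returned a set of module names
# new style: match_named_modules(...) yields (name, module) pairs,
# so recovering just the names takes a comprehension
sparse_targets = [
    name
    for name, _module in match_named_modules(
        model=model, targets=["Linear"], ignore=[]
    )
]
print(sparse_targets)  # expected: the names of the two Linear submodules

# when both the name and the module are needed (as in compress_model above),
# iterate the pairs directly
for name, module in match_named_modules(model, ["Linear"], warn_on_fail=True):
    print(name, type(module).__name__)
```

The list and set comprehensions added in `compress_model`, `decompress_model`, and `compress` are exactly this projection from pairs back to names, so the downstream name-keyed logic (`merge_names`, `compression_targets`) stays unchanged.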