41
41
apply_quantization_config ,
42
42
load_pretrained_quantization_parameters ,
43
43
)
44
- from compressed_tensors .quantization .lifecycle import expand_target_names
45
44
from compressed_tensors .quantization .utils import is_module_quantized
45
+ from compressed_tensors .utils .match import match_named_modules
46
46
from compressed_tensors .utils import (
47
47
align_module_device ,
48
48
delete_offload_parameter ,
@@ -292,13 +292,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
292
292
self .sparsity_compressor
293
293
and self .sparsity_config .format != CompressionFormat .dense .value
294
294
):
295
- sparse_targets = expand_target_names (
295
+ sparse_targets = match_named_modules (
296
296
model = model ,
297
297
targets = self .sparsity_config .targets ,
298
298
ignore = self .sparsity_config .ignore ,
299
299
)
300
+
300
301
missing_keys .update (
301
- merge_names (target , "weight" ) for target in sparse_targets
302
+ merge_names (target_name , "weight" )
303
+ for target_name , _module in sparse_targets
302
304
)
303
305
304
306
# Determine missing keys due to pack quantization
@@ -308,13 +310,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
308
310
== CompressionFormat .pack_quantized .value
309
311
):
310
312
for scheme in self .quantization_config .config_groups .values ():
311
- quant_targets = expand_target_names (
313
+ quant_targets = match_named_modules (
312
314
model = model ,
313
315
targets = scheme .targets ,
314
316
ignore = self .quantization_config .ignore ,
315
317
)
316
318
missing_keys .update (
317
- merge_names (target , "weight" ) for target in quant_targets
319
+ merge_names (target_name , "weight" )
320
+ for target_name , _module in quant_targets
318
321
)
319
322
320
323
return list (missing_keys )
@@ -345,28 +348,28 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
345
348
self .sparsity_compressor
346
349
and self .sparsity_config .format != CompressionFormat .dense .value
347
350
):
348
- sparse_targets : Set [ str ] = expand_target_names (
351
+ sparse_targets = match_named_modules (
349
352
model = model ,
350
353
targets = self .sparsity_config .targets ,
351
354
ignore = self .sparsity_config .ignore ,
352
355
)
353
356
unexpected_keys .update (
354
- merge_names (target , param )
355
- for target in sparse_targets
357
+ merge_names (target_name , param )
358
+ for target_name , _module in sparse_targets
356
359
for param in self .sparsity_compressor .compression_param_names
357
360
)
358
361
359
362
# Identify unexpected keys from quantization compression
360
363
if self .quantization_compressor :
361
364
for scheme in self .quantization_config .config_groups .values ():
362
- quant_targets : Set [ str ] = expand_target_names (
365
+ quant_targets = match_named_modules (
363
366
model = model ,
364
367
targets = scheme .targets ,
365
368
ignore = self .quantization_config .ignore ,
366
369
)
367
370
unexpected_keys .update (
368
- merge_names (target , param )
369
- for target in quant_targets
371
+ merge_names (target_name , param )
372
+ for target_name , _module in quant_targets
370
373
for param in self .quantization_compressor .compression_param_names
371
374
if param != "weight"
372
375
)
@@ -383,58 +386,65 @@ def compress_model(self, model: Module):
383
386
:param model: model containing parameters to compress
384
387
"""
385
388
module_to_scheme = map_module_to_scheme (model )
386
- sparse_compression_targets : Set [str ] = expand_target_names (
387
- model = model ,
388
- targets = self .sparsity_config .targets if self .sparsity_config else [],
389
- ignore = self .sparsity_config .ignore if self .sparsity_config else [],
390
- )
391
-
392
- for prefix , module in tqdm (model .named_modules (), desc = "Compressing model" ):
393
-
394
- if prefix in module_to_scheme or prefix in sparse_compression_targets :
395
- module_device = get_execution_device (module )
396
- is_meta = module_device .type == "meta"
397
-
398
- exec_device = "meta" if is_meta else "cpu"
399
- onloading_device = "meta" if is_meta else module_device
400
-
401
- # in the future, support compression on same device
402
- with align_module_device (module , execution_device = exec_device ):
403
- state_dict = {
404
- f"{ prefix } .{ name } " : param
405
- for name , param in module .named_parameters (recurse = False )
406
- }
407
-
408
- # quantization first
409
- if prefix in module_to_scheme :
410
- state_dict = self .quantization_compressor .compress (
411
- state_dict ,
412
- names_to_scheme = module_to_scheme ,
413
- show_progress = False ,
414
- compression_device = exec_device ,
415
- )
389
+ sparse_compression_targets = [
390
+ module_name
391
+ for module_name , _module in match_named_modules (
392
+ model = model ,
393
+ targets = self .sparsity_config .targets if self .sparsity_config else [],
394
+ ignore = self .sparsity_config .ignore if self .sparsity_config else [],
395
+ )
396
+ ]
397
+ for prefix , module in tqdm (
398
+ match_named_modules (
399
+ model ,
400
+ [* sparse_compression_targets , * module_to_scheme .keys ()],
401
+ warn_on_fail = True ,
402
+ ),
403
+ desc = "Compressing model" ,
404
+ ):
405
+ module_device = get_execution_device (module )
406
+ is_meta = module_device .type == "meta"
407
+
408
+ exec_device = "meta" if is_meta else "cpu"
409
+ onloading_device = "meta" if is_meta else module_device
410
+
411
+ # in the future, support compression on same device
412
+ with align_module_device (module , execution_device = exec_device ):
413
+ state_dict = {
414
+ f"{ prefix } .{ name } " : param
415
+ for name , param in module .named_parameters (recurse = False )
416
+ }
417
+
418
+ # quantization first
419
+ if prefix in module_to_scheme :
420
+ state_dict = self .quantization_compressor .compress (
421
+ state_dict ,
422
+ names_to_scheme = module_to_scheme ,
423
+ show_progress = False ,
424
+ compression_device = exec_device ,
425
+ )
416
426
417
- # sparsity second
418
- if prefix in sparse_compression_targets :
419
- state_dict = self .sparsity_compressor .compress (
420
- state_dict ,
421
- compression_targets = sparse_compression_targets ,
422
- show_progress = False ,
423
- )
427
+ # sparsity second
428
+ if prefix in sparse_compression_targets :
429
+ state_dict = self .sparsity_compressor .compress (
430
+ state_dict ,
431
+ compression_targets = sparse_compression_targets ,
432
+ show_progress = False ,
433
+ )
424
434
425
- # remove any existing parameters
426
- offload_device = get_offloaded_device (module )
427
- for name , _ in list (module .named_parameters (recurse = False )):
428
- delete_offload_parameter (module , name )
435
+ # remove any existing parameters
436
+ offload_device = get_offloaded_device (module )
437
+ for name , _ in list (module .named_parameters (recurse = False )):
438
+ delete_offload_parameter (module , name )
429
439
430
- # replace with compressed parameters
431
- for name , value in state_dict .items ():
432
- name = name .removeprefix (f"{ prefix } ." )
433
- value = value .to (onloading_device )
434
- param = torch .nn .Parameter (value , requires_grad = False )
435
- register_offload_parameter (module , name , param , offload_device )
440
+ # replace with compressed parameters
441
+ for name , value in state_dict .items ():
442
+ name = name .removeprefix (f"{ prefix } ." )
443
+ value = value .to (onloading_device )
444
+ param = torch .nn .Parameter (value , requires_grad = False )
445
+ register_offload_parameter (module , name , param , offload_device )
436
446
437
- module .quantization_status = QuantizationStatus .COMPRESSED
447
+ module .quantization_status = QuantizationStatus .COMPRESSED
438
448
439
449
# TODO: consider sparse compression to also be compression
440
450
if (
@@ -451,55 +461,64 @@ def decompress_model(self, model: Module):
451
461
:param model: model containing parameters to compress
452
462
"""
453
463
module_to_scheme = map_module_to_scheme (model )
454
- sparse_compression_targets : Set [str ] = expand_target_names (
455
- model = model ,
456
- targets = self .sparsity_config .targets if self .sparsity_config else [],
457
- ignore = self .sparsity_config .ignore if self .sparsity_config else [],
458
- )
459
-
460
- for prefix , module in tqdm (model .named_modules (), desc = "Decompressing model" ):
461
- if prefix in module_to_scheme or prefix in sparse_compression_targets :
462
- # in the future, support decompression on same device
463
- with align_module_device (module , execution_device = "cpu" ):
464
- state_dict = {
465
- f"{ prefix } .{ name } " : param
466
- for name , param in module .named_parameters (recurse = False )
467
- }
468
-
469
- # sparsity first
470
- if prefix in sparse_compression_targets :
471
- # sparse_compression_targets are automatically inferred by this fn
472
- generator = self .sparsity_compressor .decompress_from_state_dict (
464
+ sparse_compression_targets = [
465
+ module_name
466
+ for module_name , _module in match_named_modules (
467
+ model = model ,
468
+ targets = self .sparsity_config .targets if self .sparsity_config else [],
469
+ ignore = self .sparsity_config .ignore if self .sparsity_config else [],
470
+ )
471
+ ]
472
+
473
+ for prefix , module in tqdm (
474
+ match_named_modules (
475
+ model ,
476
+ [* sparse_compression_targets , * module_to_scheme .keys ()],
477
+ warn_on_fail = True ,
478
+ ),
479
+ desc = "Decompressing model" ,
480
+ ):
481
+ # in the future, support decompression on same device
482
+ with align_module_device (module , execution_device = "cpu" ):
483
+ state_dict = {
484
+ f"{ prefix } .{ name } " : param
485
+ for name , param in module .named_parameters (recurse = False )
486
+ }
487
+
488
+ # sparsity first
489
+ if prefix in sparse_compression_targets :
490
+ # sparse_compression_targets are automatically inferred by this fn
491
+ generator = self .sparsity_compressor .decompress_from_state_dict (
492
+ state_dict ,
493
+ )
494
+ # generates (param_path, param_val)
495
+ # of compressed and unused params
496
+ state_dict = {key : value for key , value in generator }
497
+
498
+ # quantization second
499
+ if prefix in module_to_scheme :
500
+ state_dict = (
501
+ self .quantization_compressor .decompress_module_from_state_dict (
502
+ prefix ,
473
503
state_dict ,
504
+ scheme = module_to_scheme [prefix ],
474
505
)
475
- # generates (param_path, param_val)
476
- # of compressed and unused params
477
- state_dict = {key : value for key , value in generator }
478
-
479
- # quantization second
480
- if prefix in module_to_scheme :
481
- state_dict = (
482
- self .quantization_compressor .decompress_module_from_state_dict (
483
- prefix ,
484
- state_dict ,
485
- scheme = module_to_scheme [prefix ],
486
- )
487
- )
506
+ )
488
507
489
- # remove any existing parameters
490
- exec_device = get_execution_device (module )
491
- offload_device = get_offloaded_device (module )
492
- for name , _ in list (module .named_parameters (recurse = False )):
493
- delete_offload_parameter (module , name )
508
+ # remove any existing parameters
509
+ exec_device = get_execution_device (module )
510
+ offload_device = get_offloaded_device (module )
511
+ for name , _ in list (module .named_parameters (recurse = False )):
512
+ delete_offload_parameter (module , name )
494
513
495
- # replace with decompressed parameters
496
- for name , value in state_dict .items ():
497
- name = name .removeprefix (f"{ prefix } ." )
498
- value = value .to (exec_device )
499
- param = torch .nn .Parameter (value , requires_grad = False )
500
- register_offload_parameter (module , name , param , offload_device )
514
+ # replace with decompressed parameters
515
+ for name , value in state_dict .items ():
516
+ name = name .removeprefix (f"{ prefix } ." )
517
+ value = value .to (exec_device )
518
+ param = torch .nn .Parameter (value , requires_grad = False )
519
+ register_offload_parameter (module , name , param , offload_device )
501
520
502
- module .quantization_status = QuantizationStatus .FROZEN
521
+ module .quantization_status = QuantizationStatus .FROZEN
503
522
504
523
# ----- state dict compression pathways ----- #
505
524
@@ -535,11 +554,14 @@ def compress(
535
554
)
536
555
537
556
if self .sparsity_compressor is not None :
538
- sparse_compression_targets : Set [str ] = expand_target_names (
539
- model = model ,
540
- targets = self .sparsity_config .targets ,
541
- ignore = self .sparsity_config .ignore ,
542
- )
557
+ sparse_compression_targets : Set [str ] = {
558
+ module_name
559
+ for module_name , _module in match_named_modules (
560
+ model = model ,
561
+ targets = self .sparsity_config .targets ,
562
+ ignore = self .sparsity_config .ignore ,
563
+ )
564
+ }
543
565
state_dict = self .sparsity_compressor .compress (
544
566
state_dict ,
545
567
compression_targets = sparse_compression_targets ,
@@ -598,7 +620,6 @@ def decompress(self, model_path: str, model: Module):
598
620
with override_quantization_status (
599
621
self .quantization_config , QuantizationStatus .FROZEN
600
622
):
601
-
602
623
names_to_scheme = apply_quantization_config (
603
624
model , self .quantization_config
604
625
)
0 commit comments