 import torch.nn.functional as F
 
 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                print_rank_0(f"Warning: {name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                print_rank_0(f"Warning: only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                print_rank_0(f"Warning: only per-channel smoothing is supported, skip {name}")
                 continue
 
             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
             postprocess(module)
 
             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")
 
 
 def awq(
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
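+    # awq_lite needs activation statistics collected via forward passes, so without a
+    # forward_loop the search cannot run and awq_lite is skipped.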
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return
 
     class AWQLiteHelper:
         cache_mode: bool = False
@@ -498,6 +456,25 @@ def __init__(self, module, name):
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+            self.setup()
+
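+        # Patch the module's forward with the awq_lite search forward (the original is
+        # kept as `_forward_no_awq`) and reconfigure the input quantizer; an unsupported
+        # quantization axis disables awq_lite for this module.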
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
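+        # Undo the per-module patching: drop the `_if_calib` flag set during caching and
+        # restore the module's original forward method.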
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")
 
     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
@@ -538,6 +515,8 @@ def update_loss(self, out, out_actual, alpha):
         self.awq_lite.loss[alpha] += loss
 
     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
@@ -560,7 +539,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()
 
-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
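+            # (also skip the alpha search when awq_lite has been disabled for this module)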
             return out_actual
 
         if AWQLiteHelper.cache_mode:
@@ -588,6 +568,23 @@ def forward(self, input, *args, **kwargs):
             )
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
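+            # If either quantizer ends up with an invalid pre_quant_scale or amax for this
+            # alpha, disable awq_lite for this module; postprocess will then fall back to
+            # max calibration.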
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
             out = self._forward_no_awq(input, *args, **kwargs)
 
             update_loss(self, out, out_actual, alpha)
@@ -601,19 +598,10 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
 
     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")
 
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -635,18 +623,17 @@ def forward(self, input, *args, **kwargs):
             module._if_calib = True
 
     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)
 
-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +642,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()
 
-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
 
     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
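+            # A module that saw no data during the caching pass cannot be searched;
+            # disabling awq_lite here makes postprocess fall back to max calibration.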
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data"
+                    f" through {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)
 
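+            # Always restore the patched forward method and calibration flags, even for
+            # modules where awq_lite was disabled.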
+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")
 
 
 @torch.no_grad()
@@ -858,7 +849,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)
 
-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +910,7 @@ def svdquant(
     """
 
     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(