 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
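Note: the new `print_rank_0` import replaces bare `print` calls below so that progress messages are emitted once per job instead of once per process during multi-GPU calibration. A minimal sketch of what such a rank-0 logger can look like, assuming `torch.distributed` is the process-group backend (the actual `modelopt.torch.utils` implementation may differ):

import torch.distributed as dist

def print_rank_0(*args, **kwargs):
    """Print only on global rank 0, or unconditionally when not running distributed."""
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)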
@@ -368,13 +369,13 @@ def postprocess(module):
    for name, module in model.named_modules():
        if is_quantized_linear(module):
            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                continue
            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                continue
            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                continue

            assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                postprocess(module)

            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


def awq(
@@ -481,7 +437,9 @@ def awq_lite(
    See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
    details on the remaining arguments.
    """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

    class AWQLiteHelper:
        cache_mode: bool = False
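With this change a missing `forward_loop` downgrades from a hard assertion to a warning plus early return, so calibration simply skips awq_lite. Callers still need to pass a loop that pushes representative data through the model; a minimal sketch, assuming a user-provided `calib_dataloader` (hypothetical name):

def forward_loop(model):
    # Run a few calibration batches through the model; outputs are discarded,
    # only the statistics collected by the quantizer hooks matter.
    for batch in calib_dataloader:
        model(batch)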
@@ -493,11 +451,32 @@ def __init__(self, module, name):
            self.num_search_steps = 0
            self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
            self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
            self.best_scale = None
            self.best_alpha = None
            self.is_input_quantized = module.input_quantizer.is_enabled
            self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

    def get_weight_scale(weight, block_size=None):
        org_shape = weight.shape
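Storing the per-alpha loss as 0-d tensors on the weight's device (instead of Python floats) lets the loss accumulation stay entirely on the GPU, avoiding a host sync on every forward step. A rough sketch of the pattern with a simplified MSE loss (the exact loss used by awq_lite may differ):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alphas = [a.item() for a in torch.arange(0, 1.1, 0.1)]
# One 0-d accumulator per candidate alpha, kept on the same device as the weights.
loss = {a: torch.zeros((), device=device) for a in alphas}

def update(alpha, out, out_actual):
    # Accumulate on-device; no .item(), so no GPU -> CPU sync per step.
    loss[alpha] += (out - out_actual).float().pow(2).mean()

# Comparing 0-d tensors only forces a sync when the best alpha is selected at the end.
best_alpha = min(loss, key=loss.get)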
@@ -538,6 +517,8 @@ def update_loss(self, out, out_actual, alpha):
        self.awq_lite.loss[alpha] += loss

    def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
        self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
        self.awq_lite.best_scale = get_scale(
            self.awq_lite.act_scale,
@@ -560,7 +541,8 @@ def forward(self, input, *args, **kwargs):
        out_actual = self._forward_no_awq(input, *args, **kwargs)
        self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
            return out_actual

        if AWQLiteHelper.cache_mode:
@@ -588,6 +570,23 @@ def forward(self, input, *args, **kwargs):
            )
            self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
            self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
            out = self._forward_no_awq(input, *args, **kwargs)

            update_loss(self, out, out_actual, alpha)
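The `validate_attr` guard above bails out of awq_lite for a module as soon as a computed `pre_quant_scale` or `amax` is unusable, rather than letting bad values propagate into the alpha search. Assuming the check rejects non-finite or non-positive entries (the exact semantics live in TensorQuantizer and may differ), a standalone equivalent would look roughly like:

import torch

def tensor_is_valid(t: torch.Tensor) -> bool:
    # Reject missing tensors, NaN/Inf, and non-positive scales, which would corrupt the search.
    return t is not None and torch.isfinite(t).all().item() and (t > 0).all().item()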
@@ -601,19 +600,11 @@ def forward(self, input, *args, **kwargs):
        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
            with enable_weight_access_and_writeback(module, model):
                module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()

    # Collect activation scale values
    AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

    # Lets enable stats collection
    # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -635,18 +626,17 @@ def forward(self, input, *args, **kwargs):
            module._if_calib = True

    AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
    with torch.no_grad():
        forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
        update_best_params(module)
        if hasattr(module.weight_quantizer, "_pre_quant_scale"):
            delattr(module.weight_quantizer, "_pre_quant_scale")
        if hasattr(module.input_quantizer, "_pre_quant_scale"):
            delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
            act_amax = module.input_quantizer.amax
            # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
            module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +645,29 @@ def postprocess(module):
            module.input_quantizer.amax = act_amax.amax()
            module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

    for name, module in model.named_modules():
        if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                    " forward data through the model many times."
                )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
            if not debug:
                delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


@torch.no_grad()
@@ -858,7 +852,7 @@ def forward(name, self, input, *args, **kwargs):
            with enable_weight_access_and_writeback(module, model):
                module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
    # Lets enable stats collection
    # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
    enable_stats_collection(model)
@@ -919,7 +913,7 @@ def svdquant(
    """

    def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
        u, s, vt = torch.linalg.svd(module.weight.data.double())
        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
            warnings.warn(
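For context on the `svdquant` postprocess above: it runs a full SVD on each weight and keeps a rank-`lowrank` component, in the spirit of SVDQuant's low-rank-plus-quantized-residual split. A rough, simplified sketch of just that decomposition (the real postprocess also handles the small-matrix warning path shown in the hunk):

import torch

def lowrank_split(weight: torch.Tensor, lowrank: int):
    # Full SVD in float64 for numerical stability, mirroring weight.data.double() above.
    u, s, vt = torch.linalg.svd(weight.double())
    u_r = u[:, :lowrank] * s[:lowrank]       # absorb singular values into U
    vt_r = vt[:lowrank]
    residual = weight.double() - u_r @ vt_r  # part that remains for the quantized weight
    return u_r, vt_r, residual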