 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
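For reference, the deleted `_smoothquant_fasteval` folded the per-channel activation scale directly into the weights instead of keeping a separate pre-quant scale. A minimal standalone sketch of that rescaling, with a placeholder fake-quant standing in for `module.weight_quantizer` (names here are illustrative, not ModelOpt API):

import torch
import torch.nn as nn


def fake_quant_int8(w: torch.Tensor) -> torch.Tensor:
    # Placeholder symmetric per-tensor int8 fake-quant (stand-in for the weight quantizer).
    scale = w.abs().amax().clamp_min(1e-8) / 127.0
    return (w / scale).round().clamp(-127, 127) * scale


def smooth_linear_fasteval(linear: nn.Linear, act_amax: torch.Tensor, max_bound: float = 127.0) -> None:
    # act_amax: per-input-channel activation amax, shape (in_features,). Keep scaling math in fp32.
    act_amax = act_amax.float().to(linear.weight.device)
    scale_a = max_bound / act_amax
    eps = 1.0 / (1 << 31)
    scale_a[act_amax <= eps] = 1.0  # channels with ~0 amax would overflow the scale
    inv_scale_a = act_amax / max_bound
    # Quantize the smoothed weight, then scale back so no runtime pre-quant scale is needed.
    w = fake_quant_int8(inv_scale_a * linear.weight.float()) * scale_a
    linear.weight.data.copy_(w.to(linear.weight.dtype))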
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

     class AWQLiteHelper:
         cache_mode: bool = False
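Since a missing `forward_loop` is now a warning plus an early return rather than an assertion failure, callers still have to supply one for awq_lite to do any work. A rough sketch of the kind of calibration loop typically passed in (the dataloader and config names below are illustrative):

import torch


def make_forward_loop(calib_dataloader):
    # Returns a forward_loop callable: it only needs to push calibration batches through the model.
    def forward_loop(model):
        model.eval()
        with torch.no_grad():
            for batch in calib_dataloader:
                model(batch)

    return forward_loop


# e.g. mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop=make_forward_loop(calib_dataloader))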
@@ -493,11 +451,32 @@ def __init__(self, module, name):
             self.num_search_steps = 0
             self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
             self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
             self.best_scale = None
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
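The per-alpha losses now start as zero-dim tensors on the weight's device instead of Python floats, so `update_loss` can accumulate without calling `.item()` (a device-to-host sync) on every search step; the single sync happens later in `update_best_params`. A small sketch of the pattern in isolation:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
alpha_step = 0.1
# One on-device accumulator per candidate alpha; `+=` stays asynchronous on CUDA.
loss = {
    k.item(): torch.zeros((), device=device, dtype=torch.float32)
    for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
}

for _ in range(4):  # stand-in for the search-mode forward passes
    for alpha in loss:
        err = torch.rand((), device=device)  # stand-in for the per-alpha MSE
        loss[alpha] += err  # no .item() here, so no GPU sync per step

# One sync at the end, mirroring update_best_params()
loss = {k: float(v) for k, v in loss.items()}
best_alpha = min(loss, key=loss.get)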
@@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
     def update_loss(self, out, out_actual, alpha):
         out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
         out = out[0] if isinstance(out, tuple) else out
-        loss = (out - out_actual).float().pow(2).mean().item()
+        loss = (out - out_actual).float().pow(2).mean()
         self.awq_lite.loss[alpha] += loss

     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
+        self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
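For context, the scale searched over per alpha is the usual AWQ activation/weight balance. A hedged sketch of that formula, which the in-tree `get_scale` may implement with extra details such as a tensor-parallel reduction:

import torch


def awq_scale(x_max: torch.Tensor, w_max: torch.Tensor, alpha: float) -> torch.Tensor:
    # x_max: per-input-channel activation amax; w_max: per-input-channel weight scale.
    scale = (x_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-4)
    # Normalize so the scales are centered around 1 and stay finite in fp16.
    return scale / (scale.max() * scale.min()).sqrt()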
@@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual

         if AWQLiteHelper.cache_mode:
@@ -589,7 +572,6 @@ def forward(self, input, *args, **kwargs):
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
             out = self._forward_no_awq(input, *args, **kwargs)
-
             update_loss(self, out, out_actual, alpha)

         self.awq_lite.num_search_steps += 1
@@ -601,19 +583,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-                bind_forward_method(module, forward, "_forward_no_awq")
-
-                if module.input_quantizer.is_enabled:
-                    module.input_quantizer.disable()
-                    if module.input_quantizer.axis not in [None, -1]:
-                        raise NotImplementedError(
-                            "input quantization needs to be per-tensor or None for AWQ algorithm"
-                        )
-                    module.input_quantizer.axis = -1
+                module.awq_lite.setup()

     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -631,22 +605,25 @@ def forward(self, input, *args, **kwargs):
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+                torch.isnan(module.awq_lite.weight_scale)
+            ):
+                module.awq_lite.is_enabled = False
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +632,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
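The print-to-print_rank_0 swaps throughout this diff keep multi-GPU calibration logs to one copy per job. ModelOpt ships its own helper in `modelopt.torch.utils`; a from-scratch version would look roughly like this sketch:

import torch.distributed as dist


def print_rank_0_sketch(*args, **kwargs):
    # Print only on global rank 0 (or when not running distributed) so each message appears once.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)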
@@ -858,7 +839,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +900,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
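The svdquant postprocess above decomposes each weight with `torch.linalg.svd` and keeps a rank-`lowrank` branch. A compressed sketch of that split (the real code also handles the quantizer state and the too-small-rank case guarded by the `warnings.warn` above):

import torch


def lowrank_split(weight: torch.Tensor, lowrank: int = 32):
    # Split W into a rank-`lowrank` factor pair plus a residual that is left for quantization.
    u, s, vt = torch.linalg.svd(weight.double(), full_matrices=False)
    us = u * s.unsqueeze(0)  # scale the columns of U by the singular values
    l1, l2 = us[:, :lowrank], vt[:lowrank]
    residual = weight.double() - l1 @ l2
    return l1.to(weight.dtype), l2.to(weight.dtype), residual.to(weight.dtype)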