import torch.nn.functional as F

from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

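Note: the new print_rank_0 import backs the logging changes below, so progress messages appear once per job rather than once per rank in multi-GPU runs. A minimal sketch of such a helper, assuming the usual torch.distributed rank check (an illustration, not modelopt's actual implementation):

import torch.distributed as dist


def print_rank_0_sketch(*args, **kwargs):
    # Print only on rank 0, or everywhere when torch.distributed is not initialized.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)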
@@ -368,13 +369,13 @@ def postprocess(module):
    for name, module in model.named_modules():
        if is_quantized_linear(module):
            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                continue
            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                continue
            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                continue

            assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                postprocess(module)

            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


def awq(
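For reference, the per-channel scaling math that the removed _smoothquant_fasteval helper performed can be reproduced in isolation. The tensors below are made up, and max_bound assumes an int8 quantizer (maxbound 127):

import torch

# Per-channel activation amax, kept in fp32 so the scaling math is numerically safe.
act_amax = torch.tensor([0.0, 0.5, 2.0], dtype=torch.float32)
max_bound = 127.0

scale_a = max_bound / act_amax           # zero-amax channels overflow to inf here
epsilon = 1.0 / (1 << 31)
if act_amax.min() <= epsilon:
    scale_a[act_amax <= epsilon] = 1.0   # mask out zero-amax channels
inv_scale_a = act_amax / max_bound       # folded into the weight before quantizing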
@@ -481,7 +437,9 @@ def awq_lite(
    See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
    details on the remaining arguments.
    """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

    class AWQLiteHelper:
        cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
            self.num_search_steps = 0
            self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
            self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
            self.best_scale = None
            self.best_alpha = None
            self.is_input_quantized = module.input_quantizer.is_enabled
            self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

    def get_weight_scale(weight, block_size=None):
        org_shape = weight.shape
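The loss dictionary now starts as scalar zero tensors on the weight's device instead of Python floats, so the per-alpha MSE accumulated in update_loss stays on the GPU and is converted to floats only once in update_best_params. A standalone sketch of that pattern (values are illustrative):

import torch

alpha_step = 0.1
device = "cpu"  # stand-in for module.weight.device

loss = {
    k.item(): torch.zeros((), device=device, dtype=torch.float32)
    for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
}
for alpha in loss:
    loss[alpha] += torch.rand(())              # per-step MSE term stays a tensor
loss = {k: float(v) for k, v in loss.items()}  # one host sync at the very end
best_alpha = min(loss, key=loss.get)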
@@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
    def update_loss(self, out, out_actual, alpha):
        out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
        out = out[0] if isinstance(out, tuple) else out
-        loss = (out - out_actual).float().pow(2).mean().item()
+        loss = (out - out_actual).float().pow(2).mean()
        self.awq_lite.loss[alpha] += loss

    def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
+        self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
        self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
        self.awq_lite.best_scale = get_scale(
            self.awq_lite.act_scale,
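For context on what the alpha search optimizes: following the AWQ formulation, the candidate scale for a given alpha balances activation and weight magnitudes, roughly s = x_max**alpha / w_max**(1 - alpha) with a range normalization. A sketch under that assumption (modelopt's get_scale may differ in details such as tensor-parallel reduction):

import torch


def get_scale_sketch(x_max, w_max, alpha):
    # AWQ-style per-channel scale; the clamp avoids divide-by-zero blowups.
    scales = (x_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-4)
    return scales / (scales.max() * scales.min()).sqrt()  # normalize the spread


x_max = torch.rand(16) + 0.1  # per-channel activation scale (illustrative)
w_max = torch.rand(16) + 0.1  # per-channel weight scale (illustrative)
scale = get_scale_sketch(x_max, w_max, alpha=0.5)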
@@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
        out_actual = self._forward_no_awq(input, *args, **kwargs)
        self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
            return out_actual

        if AWQLiteHelper.cache_mode:
@@ -588,6 +571,23 @@ def forward(self, input, *args, **kwargs):
            )
            self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
            self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
            out = self._forward_no_awq(input, *args, **kwargs)

            update_loss(self, out, out_actual, alpha)
@@ -601,19 +601,11 @@ def forward(self, input, *args, **kwargs):
        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
            with enable_weight_access_and_writeback(module, model):
                module.awq_lite = AWQLiteHelper(module, name)
-                bind_forward_method(module, forward, "_forward_no_awq")
-
-                if module.input_quantizer.is_enabled:
-                    module.input_quantizer.disable()
-                    if module.input_quantizer.axis not in [None, -1]:
-                        raise NotImplementedError(
-                            "input quantization needs to be per-tensor or None for AWQ algorithm"
-                        )
-                    module.input_quantizer.axis = -1
+                module.awq_lite.setup()

    # Collect activation scale values
    AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

    # Lets enable stats collection
    # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
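module.awq_lite.setup() now encapsulates the forward patching that used to be done inline here. The general pattern is to stash the original bound forward under a private name, install the search forward, and restore it in cleanup(); bind_forward_method/unpatch_forward_method are modelopt utilities, and the sketch below is only an assumed illustration of that idea, not their actual implementation:

import types

import torch.nn as nn


def bind_forward_sketch(module: nn.Module, new_forward, orig_name: str):
    # Keep the original forward reachable under orig_name, then install the
    # replacement as an instance-level method.
    setattr(module, orig_name, module.forward)
    module.forward = types.MethodType(new_forward, module)


def unpatch_forward_sketch(module: nn.Module, orig_name: str):
    # Restore the original forward and drop the temporary attribute.
    module.forward = getattr(module, orig_name)
    delattr(module, orig_name)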
@@ -635,18 +627,17 @@ def forward(self, input, *args, **kwargs):
            module._if_calib = True

    AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
    with torch.no_grad():
        forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
        update_best_params(module)
        if hasattr(module.weight_quantizer, "_pre_quant_scale"):
            delattr(module.weight_quantizer, "_pre_quant_scale")
        if hasattr(module.input_quantizer, "_pre_quant_scale"):
            delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
            act_amax = module.input_quantizer.amax
            # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
            module.input_quantizer._amax_for_smoothing = act_amax.cpu()
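In postprocess, the full per-channel activation amax is stashed for later smoothing while the input quantizer itself falls back to a single per-tensor value; in isolation the reduction is just the following (made-up values):

import torch

act_amax = torch.tensor([0.3, 1.7, 0.9])  # per-channel amax
amax_for_smoothing = act_amax.cpu()       # kept around for smoothing later
per_tensor_amax = act_amax.amax()         # tensor(1.7), used by the input quantizer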
@@ -655,25 +646,29 @@ def postprocess(module):
            module.input_quantizer.amax = act_amax.amax()
            module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

    for name, module in model.named_modules():
        if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                    " forward data through the model many times."
                )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
            if not debug:
                delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


@torch.no_grad()
@@ -858,7 +853,7 @@ def forward(name, self, input, *args, **kwargs):
            with enable_weight_access_and_writeback(module, model):
                module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
    # Lets enable stats collection
    # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
    enable_stats_collection(model)
@@ -919,7 +914,7 @@ def svdquant(
    """

    def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
        u, s, vt = torch.linalg.svd(module.weight.data.double())
        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
            warnings.warn(
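The svdquant postprocess factors each weight with an SVD and keeps the top lowrank components as a separate low-rank branch. A sketch of that split (the economy SVD and residual handling here are simplifying assumptions, not necessarily modelopt's exact recipe):

import torch

weight = torch.randn(128, 256, dtype=torch.float64)  # made-up weight
lowrank = 32
u, s, vt = torch.linalg.svd(weight, full_matrices=False)
us = u[:, :lowrank] * s[:lowrank]  # (out_features, lowrank) low-rank factor
v = vt[:lowrank]                   # (lowrank, in_features)
residual = weight - us @ v         # remainder that goes through the quantizer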