 import torch.nn.functional as F
 
 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue
 
             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)
 
             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")
 
 
 def awq(
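For reference, the removed _smoothquant_fasteval path above bakes the smoothing scale directly into the stored weights instead of keeping a separate pre-quant scale. Below is a minimal, self-contained sketch of that folding math; the tensors and the fake-quant helper are invented stand-ins for the module's quantizers, not modelopt code.

import torch

def fake_quant_int8(w: torch.Tensor) -> torch.Tensor:
    # Invented stand-in for module.weight_quantizer: symmetric int8 quantize-dequantize.
    scale = w.abs().amax() / 127.0
    return torch.clamp(torch.round(w / scale), -127, 127) * scale

# Stand-ins for module.input_quantizer.amax / .maxbound and module.weight.
act_amax = torch.tensor([0.5, 2.0, 0.0, 8.0])  # per-input-channel activation amax
max_bound = 127.0
weight = torch.randn(3, 4)                     # [out_features, in_features]

# Keep the scaling math in fp32 and mask zero-amax channels so scale_a stays finite.
scale_a = max_bound / act_amax.float()
epsilon = 1.0 / (1 << 31)
scale_a[act_amax <= epsilon] = 1.0
inv_scale_a = act_amax.float() / max_bound

# Normalize each input channel by its activation range, fake-quantize, then undo the
# scaling. The quantization error is baked into the stored weight, which is why the
# removed code disables the weight quantizer afterwards.
folded_weight = (fake_quant_int8(weight * inv_scale_a) * scale_a).to(weight.dtype)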
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return
 
     class AWQLiteHelper:
         cache_mode: bool = False
@@ -498,6 +456,24 @@ def __init__(self, module, name):
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")
 
     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
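The new setup/cleanup pair above centralizes the forward patching that the module loop used to do inline. As a rough illustration of that patch-and-restore pattern, here is a generic sketch; patch_forward and unpatch_forward are hypothetical helpers written for this example and are only assumed to behave like modelopt's bind_forward_method/unpatch_forward_method (saving the original bound forward under the given attribute name).

import types
import torch.nn as nn

def patch_forward(module: nn.Module, new_forward, saved_name: str) -> None:
    # Keep a handle to the original bound forward, then install the replacement.
    setattr(module, saved_name, module.forward)
    module.forward = types.MethodType(new_forward, module)

def unpatch_forward(module: nn.Module, saved_name: str) -> None:
    # Restore the saved forward and drop the bookkeeping attribute.
    module.forward = getattr(module, saved_name)
    delattr(module, saved_name)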
@@ -538,6 +514,8 @@ def update_loss(self, out, out_actual, alpha):
         self.awq_lite.loss[alpha] += loss
 
     def update_best_params(self):
+        if not module.awq_lite.is_enabled:
+            return
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
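update_best_params above reduces to a dictionary argmin over the per-alpha losses accumulated during the search pass. A tiny standalone illustration (the loss values are invented):

# awq_lite tracks one accumulated loss per candidate alpha; the best alpha is simply
# the key with the smallest loss. The numbers here are made up for illustration.
loss = {0.0: 0.91, 0.25: 0.40, 0.5: 0.32, 0.75: 0.35, 1.0: 0.58}
best_alpha = min(loss, key=loss.get)
assert best_alpha == 0.5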
@@ -560,7 +538,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()
 
-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual
 
         if AWQLiteHelper.cache_mode:
@@ -588,6 +567,23 @@ def forward(self, input, *args, **kwargs):
             )
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
             out = self._forward_no_awq(input, *args, **kwargs)
 
             update_loss(self, out, out_actual, alpha)
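The search loop above installs reciprocal pre-quant scales on the input and weight quantizers for each candidate alpha. Without quantization this rescaling is an exact identity for a linear layer, so the per-alpha loss measured against out_actual isolates the quantization error introduced by that scale. A quick numerical check of the identity, with shapes chosen arbitrarily for this example:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)   # [batch, tokens, in_features]
w = torch.randn(32, 16)     # [out_features, in_features]
s = torch.rand(16) + 0.5    # per-input-channel AWQ scale, kept positive

# Scaling activations by 1/s and weights by s cancels out in exact arithmetic...
out_ref = F.linear(x, w)
out_scaled = F.linear(x * (1.0 / s), w * s)
torch.testing.assert_close(out_scaled, out_ref, rtol=1e-4, atol=1e-4)
# ...so with fake quantization enabled, the difference from out_actual reflects only
# the quantization error under that particular choice of scale.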
@@ -601,19 +597,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()
 
     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")
 
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -635,18 +623,17 @@ def forward(self, input, *args, **kwargs):
             module._if_calib = True
 
     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)
 
-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +642,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()
 
-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
 
     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)
 
+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")
 
 
 @torch.no_grad()
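When a module ends up with is_enabled False, the postprocess above falls back to plain max calibration of the weight quantizer instead of applying an AWQ scale. Conceptually, max calibration just derives amax from the tensor's absolute maximum; the helper below is a generic sketch of that idea, not the modelopt max_calibrate implementation.

import torch

def absmax_amax(w: torch.Tensor, axis: int | None = None) -> torch.Tensor:
    # Per-tensor (axis=None) or per-channel absolute-max calibration.
    if axis is None:
        return w.abs().amax()
    reduce_dims = tuple(d for d in range(w.dim()) if d != axis)
    return w.abs().amax(dim=reduce_dims, keepdim=True)

w = torch.randn(32, 16)
per_tensor_amax = absmax_amax(w)           # scalar
per_channel_amax = absmax_amax(w, axis=0)  # shape [32, 1], one amax per output row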
@@ -858,7 +849,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)
 
-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +910,7 @@ def svdquant(
919910 """
920911
921912 def postprocess (module , name ):
922- print (f"SVD { name } " )
913+ print_rank_0 (f"SVD { name } " )
923914 u , s , vt = torch .linalg .svd (module .weight .data .double ())
924915 if u .shape [1 ] < lowrank or vt .shape [0 ] < lowrank :
925916 warnings .warn (
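The svdquant postprocess shown above computes a full SVD of each weight; the low-rank factors it keeps can be formed from the torch.linalg.svd outputs as sketched below. The shapes and the lowrank value are illustrative only, and the low-rank/residual split is stated as the general SVDQuant-style idea rather than the exact modelopt behavior.

import torch

lowrank = 32                              # illustrative; the real value comes from the config
weight = torch.randn(512, 256).double()   # [out_features, in_features]

u, s, vt = torch.linalg.svd(weight)       # u: [512, 512], s: [256], vt: [256, 256]
# Keep the top `lowrank` singular directions as a high-precision low-rank branch and
# leave the remainder to the quantized path.
low_rank_part = u[:, :lowrank] @ torch.diag(s[:lowrank]) @ vt[:lowrank, :]
residual = weight - low_rank_part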