 import torch.nn.functional as F
 
 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue
 
             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)
 
             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")
 
 
 def awq(
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return
 
     class AWQLiteHelper:
         cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
             self.num_search_steps = 0
             self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
             self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
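+            # Accumulate per-alpha losses as device tensors; update_best_params() converts them to floats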
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
             self.best_scale = None
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
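+                # Per-channel input quantization is not supported by AWQ; disable awq_lite for this module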
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")
 
     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
@@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
     def update_loss(self, out, out_actual, alpha):
         out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
         out = out[0] if isinstance(out, tuple) else out
-        loss = (out - out_actual).float().pow(2).mean().item()
+        loss = (out - out_actual).float().pow(2).mean()
         self.awq_lite.loss[alpha] += loss
 
     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
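+        # Materialize the accumulated loss tensors as Python floats before picking the best alpha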
+        self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
@@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()
 
-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual
 
         if AWQLiteHelper.cache_mode:
@@ -589,7 +572,6 @@ def forward(self, input, *args, **kwargs):
                 self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
                 self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
                 out = self._forward_no_awq(input, *args, **kwargs)
-
                 update_loss(self, out, out_actual, alpha)
 
             self.awq_lite.num_search_steps += 1
@@ -601,19 +583,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()
 
     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")
 
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -631,22 +605,25 @@ def forward(self, input, *args, **kwargs):
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
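+            # Disable awq_lite for this module if NaNs show up in the collected activation or weight scales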
+            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+                torch.isnan(module.awq_lite.weight_scale)
+            ):
+                module.awq_lite.is_enabled = False
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True
 
     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)
 
-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +632,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
         module.input_quantizer.enable()
 
-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
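+        # Apply the searched pre-quant scales only if awq_lite stayed enabled; otherwise fall back to max calibration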
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
 
     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)
 
+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")
 
 
 @torch.no_grad()
@@ -858,7 +839,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)
 
-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +900,7 @@ def svdquant(
     """
 
     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(