 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
             postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
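The deleted `_smoothquant_fasteval` path folded the per-channel activation range directly into the weights before quantizing them. Below is a standalone sketch of that scaling math on toy tensors; the `fake_quant_int8` helper is a stand-in for `module.weight_quantizer`, not the modelopt API.

import torch

def fake_quant_int8(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for module.weight_quantizer: symmetric per-tensor int8 fake quantization.
    scale = w.abs().amax() / 127.0
    return (w / scale).round().clamp(-128, 127) * scale

act_amax = torch.tensor([0.5, 2.0, 0.0, 8.0])  # per-input-channel activation amax
weight = torch.randn(3, 4)                     # (out_features, in_features)
max_bound = 127.0                              # int8 maxbound

# Keep the scaling math in fp32; mask near-zero channels so scale_a cannot overflow.
scale_a = max_bound / act_amax
epsilon = 1.0 / (1 << 31)
scale_a[act_amax <= epsilon] = 1.0
inv_scale_a = act_amax / max_bound

# Shrink each input channel, fake-quantize, then undo the shrink in the stored weight.
smoothed_weight = (fake_quant_int8(inv_scale_a * weight.float()) * scale_a).to(weight.dtype)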
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

     class AWQLiteHelper:
         cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
             self.num_search_steps = 0
             self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
             self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
             self.best_scale = None
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
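Accumulating the per-alpha losses as zero-dimensional tensors on the weight's device (rather than Python floats via `.item()`) keeps the search loop free of per-step GPU-to-host synchronization; `update_best_params` converts to floats once at the end. A minimal standalone sketch of the same pattern, with made-up outputs standing in for the module forwards:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
alphas = [round(a, 1) for a in torch.arange(0, 1.1, 0.1).tolist()]

# On-device accumulators: no .item() (and hence no device sync) inside the search loop.
loss = {a: torch.zeros((), device=device, dtype=torch.float32) for a in alphas}
for _ in range(4):  # pretend calibration batches
    out_actual = torch.randn(8, device=device)
    for a in alphas:
        out = out_actual + 0.01 * a * torch.randn(8, device=device)  # pretend AWQ-scaled output
        loss[a] += (out - out_actual).float().pow(2).mean()

# Single host transfer at the end, mirroring update_best_params().
loss = {a: float(v) for a, v in loss.items()}
best_alpha = min(loss, key=loss.get)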
@@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
     def update_loss(self, out, out_actual, alpha):
         out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
         out = out[0] if isinstance(out, tuple) else out
-        loss = (out - out_actual).float().pow(2).mean().item()
+        loss = (out - out_actual).float().pow(2).mean()
         self.awq_lite.loss[alpha] += loss

     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
+        self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
@@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual

         if AWQLiteHelper.cache_mode:
@@ -588,6 +571,23 @@ def forward(self, input, *args, **kwargs):
             )
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
             out = self._forward_no_awq(input, *args, **kwargs)

             update_loss(self, out, out_actual, alpha)
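The new guard calls `validate_attr` on the freshly written `_pre_quant_scale` and `_amax` and permanently disables AWQ-lite for the module if either looks bad, returning the unscaled output instead of a corrupted one. A rough sketch of that check-and-disable shape, assuming the validation simply rejects missing or non-finite tensors (the real `validate_attr` may do more):

import warnings
import torch

def _looks_valid(t):
    # Hypothetical stand-in for the quantizer's validate_attr check.
    return t is not None and bool(torch.isfinite(t).all())

def check_quantizer_attrs(attrs: dict, name: str = "layer") -> bool:
    """Return False (and warn) if any attribute should disable AWQ-lite for this module."""
    for attr_name, value in attrs.items():
        if not _looks_valid(value):
            warnings.warn(f"awq_lite: {attr_name} is not valid for {name}, skipping awq_lite")
            return False
    return True

# Example: a NaN pre-quant scale trips the guard; the caller then falls back to out_actual.
ok = check_quantizer_attrs(
    {"_pre_quant_scale": torch.tensor([1.0, float("nan")]), "_amax": torch.tensor(2.0)}
)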
@@ -601,19 +601,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()

     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
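The routine still drives calibration with two passes over the same `forward_loop`: one with `AWQLiteHelper.cache_mode = True` to gather activation statistics, then one with it off to score each alpha. A toy sketch of that shared class-level mode flag (hypothetical names, not the modelopt helpers):

import torch

class _Mode:
    cache: bool = True  # class-level flag seen by every patched forward

stats = {"act_scale": torch.zeros(())}

def patched_forward(x: torch.Tensor) -> torch.Tensor:
    if _Mode.cache:
        stats["act_scale"] += x.abs().mean()  # pass 1: collect statistics only
        return x
    return x / stats["act_scale"].clamp(min=1e-6)  # pass 2: use the cached statistics

def forward_loop() -> None:
    for _ in range(3):
        patched_forward(torch.randn(16))

_Mode.cache = True
forward_loop()   # caching pass
_Mode.cache = False
forward_loop()   # search pass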
@@ -635,18 +627,17 @@ def forward(self, input, *args, **kwargs):
             module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +646,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
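Where the old code asserted, the loop now just marks modules that never received data and `postprocess` falls back to plain max calibration for them. A rough standalone sketch of that fallback, with a toy weight-only max calibration in place of modelopt's `max_calibrate`:

import warnings
import torch

def max_calibrate_weight_only(w: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for max calibration: range comes from the weight itself, no activation search.
    amax = w.abs().amax()
    return ((w / amax * 127).round().clamp(-128, 127) / 127) * amax

def finalize_module(name: str, w: torch.Tensor, best_scale, awq_enabled: bool) -> torch.Tensor:
    if awq_enabled and best_scale is not None:
        return w * best_scale  # placeholder for apply_pre_quant_scale_and_smooth + quantization
    warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
    return max_calibrate_weight_only(w)

# A module whose experts saw no tokens (num_cache_steps == 0) takes the fallback path.
w_q = finalize_module("mlp.experts.3.fc1", torch.randn(4, 8), best_scale=None, awq_enabled=False)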
@@ -858,7 +853,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +914,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(