 import os
 import pprint
 import sys
+import warnings
 
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 
 from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq
 from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
 from brevitas_examples.llm.llm_quant.calibrate import apply_calibration
+from brevitas_examples.llm.llm_quant.data_utils import collate_fn
 from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model
 from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization
 from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization
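Note on the new `collate_fn` import: its implementation lives in `brevitas_examples.llm.llm_quant.data_utils` and is not part of this diff. As a rough sketch only (an assumption about its behavior, not the actual Brevitas code), a collate function for fixed-length tokenized samples typically stacks each key of the per-sample dicts along the batch dimension:

import torch

# Assumed sketch; the real collate_fn is defined in data_utils and may differ.
def collate_fn_sketch(samples):
    # `samples` is a list of dicts, e.g. {"input_ids": Tensor, "attention_mask": Tensor}
    batch = {}
    for key in samples[0]:
        tensors = [torch.as_tensor(s[key]).reshape(1, -1) for s in samples]
        # Assumes every sample shares the same sequence length (fixed-seqlen datasets).
        batch[key] = torch.cat(tensors, dim=0)
    return batch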
@@ -92,7 +95,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     with torch.no_grad(), rmsnorm_patch(model, model.config) as patcher:
         rmsnorm_classes = patcher.rmsnorm_classes
         with make_dynamo_compatible(model) as dynamo_comp:
-            fx_model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
+            fx_model, guards = torch._dynamo.export(dynamo_comp.model)(**next(iter(calibration_loader)))
     if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
         m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
         fx_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
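Throughout the commit, `calibration_loader[0]` becomes `next(iter(calibration_loader))`: a `torch.utils.data.DataLoader` is iterable but not subscriptable, so the first collated batch has to be pulled from an iterator. A minimal illustration (the toy dataset and lambda below are placeholders, not part of the commit):

import torch
from torch.utils.data import DataLoader

# Hypothetical toy dataset standing in for the tokenized calibration samples.
toy_dataset = [{"input_ids": torch.arange(4).unsqueeze(0)} for _ in range(8)]
loader = DataLoader(
    toy_dataset,
    batch_size=2,
    collate_fn=lambda s: {"input_ids": torch.cat([x["input_ids"] for x in s], dim=0)})

# loader[0]                       # raises TypeError: 'DataLoader' object is not subscriptable
first_batch = next(iter(loader))  # yields the first collated batch instead
print(first_batch["input_ids"].shape)  # torch.Size([2, 4])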
@@ -199,7 +202,7 @@ def model_export(model, tokenizer, ref_input, args, config=None):
 
 
 def fx_required(args):
-    return True if args.weight_equalization or args.act_equalization == 'fx' or args.rotation == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or args.quant_sdpa == 'fx' else False
+    return args.weight_equalization or args.act_equalization == 'fx' or args.rotation == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or args.quant_sdpa == 'fx'
 
 
 # Recursive function to unwrap equalized layers
@@ -232,9 +235,13 @@ def quantize_llm(args, extra_args=None):
     quant_ppl = None
 
     require_fx = fx_required(args)
+    if require_fx and args.calibration_batch_size > 1:
+        warnings.warn(
+            f"The provided configuration requires fx and has a batch size of {args.calibration_batch_size}.\nErrors may occur when using fx and batch_size > 1.\nIf you experience any issues, try changing the configuration to avoid using fx or set the batch_size to 1."
+        )
 
     # Load the data for calibration and evaluation.
-    calibration_loader = get_dataset_for_model(
+    calibration_dataset = get_dataset_for_model(
         args.model,
         bos_preprocessing=args.bos_preprocessing,
         dataset_name=args.dataset,
@@ -246,7 +253,11 @@ def quantize_llm(args, extra_args=None):
         require_fx=require_fx and args.export_target is not None,
         device=None)
 
-    validation_loader = get_dataset_for_model(
+    # Batched data loader to accelerate GPXQ algorithms
+    calibration_loader = DataLoader(
+        dataset=calibration_dataset, batch_size=args.calibration_batch_size, collate_fn=collate_fn)
+
+    validation_dataset = get_dataset_for_model(
         args.model,
         bos_preprocessing=args.bos_preprocessing,
         dataset_name=args.dataset,
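Note the split this hunk introduces: the raw `calibration_dataset` stays around for consumers that index individual samples (e.g. AWQ and learned round further down), while the batched `calibration_loader` feeds forward passes and GPXQ-style calibration. A small sketch of the two access patterns as I read the diff (the comments are an interpretation, not part of the commit):

# Per-sample access keeps using the dataset:
sample = calibration_dataset[0]         # one model-ready tokenized example
# Batched access goes through the DataLoader:
batch = next(iter(calibration_loader))  # dict of tensors collated to calibration_batch_size
# model(**batch)                        # how calibration forward passes consume the batches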
@@ -262,7 +273,7 @@ def quantize_llm(args, extra_args=None):
         # Extra arguments should be used as training arguments for rotation optimization
         rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args)
         # Load the data for rotation optimization
-        rot_calibration_loader = get_dataset_for_model(
+        rot_calibration_dataset = get_dataset_for_model(
             args.model,
             bos_preprocessing=args.bos_preprocessing,
             dataset_name=args.dataset,
@@ -281,7 +292,7 @@ def quantize_llm(args, extra_args=None):
         print("Float model eval...")
         model = offload_model(model)
         float_ppl = compute_perplexity(
-            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+            model, validation_dataset, context_length=args.seqlen // 2, tokenizer=tokenizer)
         remove_hooks(model)
         print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")
 
@@ -290,7 +301,7 @@ def quantize_llm(args, extra_args=None):
         with torch.no_grad(), rmsnorm_patch(model, model.config, enabled=args.replace_rmsnorm) as patcher:
             rmsnorm_classes = patcher.rmsnorm_classes
             with make_dynamo_compatible(model) as dynamo_comp:
-                model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
+                model, guards = torch._dynamo.export(model)(**next(iter(calibration_loader)))
         # Blockwise optimization does not work with FX at the moment
         args.gpxq_block_name = None
         model.eval()
@@ -317,7 +328,7 @@ def quantize_llm(args, extra_args=None):
         print("Inserting SDPA quantizable module")
         model = offload_model(model)
         with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         remove_hooks(model)
     elif args.quant_sdpa == 'eager':
         model = replace_sdpa_with_quantizable_layers(
@@ -365,7 +376,7 @@ def quantize_llm(args, extra_args=None):
         offload_model(model)
         print(f"Apply act equalization (SmoothQuant) with alpha {args.act_equalization_alpha}")
         if args.load_checkpoint:
-            loader = [calibration_loader[0]]
+            loader = [next(iter(calibration_loader))]
         else:
             loader = calibration_loader
         apply_act_equalization(
@@ -479,7 +490,7 @@ def quantize_llm(args, extra_args=None):
         apply_awq(
             model=model,
             tokenizer=tokenizer,
-            calibration_loader=calibration_loader,
+            calibration_dataset=calibration_dataset,
             args=args,
             auto_scale=args.awq_scale,
             mse_range=args.awq_clip,
@@ -522,7 +533,7 @@ def quantize_llm(args, extra_args=None):
     with quantization_cm:
         # We initialize weights scale factor
         with torch.no_grad():
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
 
         if args.compile_ptq:
             for m in model.modules():
@@ -540,7 +551,7 @@ def quantize_llm(args, extra_args=None):
         apply_rotation_optimization(
             model=model,
             tokenizer=tokenizer,
-            train_dataset=rot_calibration_loader,
+            train_dataset=rot_calibration_dataset,
             training_args=rot_optimization_args,
         )
         # Remove hooks from optimization
@@ -561,18 +572,19 @@ def quantize_llm(args, extra_args=None):
             dtype=torch.float32)
         model = offload_model(model)
         with torch.no_grad():
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         print("SVDQuant applied.")
 
     if args.learned_round:
         print("Applying learned round...")
         if args.load_checkpoint:
             iters = 1
-            loader = [calibration_loader[0]]
+            loader = [calibration_dataset[0]]
         else:
             iters = args.learned_round_iters
-            loader = calibration_loader
+            loader = calibration_dataset
         remove_hooks(model)
+        # TODO (pml): Fix learned round type hints
         apply_learned_round(
             model,
             loader,
@@ -650,18 +662,17 @@ def quantize_llm(args, extra_args=None):
     if args.eval and not args.no_quantize:
         print("Model eval...")
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
             quant_ppl = compute_perplexity(
-                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+                model, validation_dataset, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
     few_shot_eval_results = dict()
     if args.few_shot_eval == 'lm_eval':
         from lm_eval import evaluator
         from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
             batch_size = 'auto' if args.few_shot_override_batch_size is None else args.few_shot_override_batch_size
-
             wrapped_model = HFLM(
                 pretrained=model, add_bos_token=True,
                 batch_size=batch_size)  # need to wrap for LLM eval
@@ -681,7 +692,7 @@ def quantize_llm(args, extra_args=None):
     elif args.few_shot_eval == 'lighteval':
 
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         remove_hooks(model)
 
         from brevitas_examples.llm.eval_lighteval import run_lighteval
@@ -703,7 +714,7 @@ def quantize_llm(args, extra_args=None):
         print(f"Export to {args.export_target}")
         # Currently we always export with a float32 container to avoid float16 CPU errors
         model = model.to(dtype=torch.float32)
-        model_export(model, tokenizer, calibration_loader[0], args, config)
+        model_export(model, tokenizer, next(iter(calibration_loader)), args, config)
 
     return {"float_ppl": float_ppl, "quant_ppl": quant_ppl, **few_shot_eval_results}, model
 