2121import modelopt .torch .opt as mto
2222import modelopt .torch .quantization as mtq
2323from modelopt .torch .export import get_model_type
24- from modelopt .torch .export .convert_hf_config import convert_hf_quant_config_format
25- from modelopt .torch .export .quant_utils import postprocess_state_dict
26- from modelopt .torch .export .unified_export_hf import _export_hf_checkpoint
24+ from modelopt .torch .export .unified_export_hf import export_hf_checkpoint
2725from modelopt .torch .quantization .config import need_calibration
2826from modelopt .torch .quantization .utils import patch_fsdp_mp_dtypes
2927from modelopt .torch .utils .dataset_utils import get_dataset_dataloader , get_supported_datasets
@@ -121,11 +119,6 @@ def parse_args():
121119 action = "store_true" ,
122120 help = "Trust remote code for HuggingFace models" ,
123121 )
124- parser .add_argument (
125- "--attn_implementation" ,
126- type = str ,
127- help = "Attention implementation to use (passed to HF model loading)" ,
128- )
129122 parser .add_argument ("--awq_block_size" , default = 0 , type = int )
130123
131124 args = parser .parse_args ()
@@ -159,6 +152,8 @@ def load_and_prepare_model(
159152 )
160153 model .eval ()
161154 model_type = get_model_type (model )
155+ # Need the original architectures for export
156+ # FSDP prefix is added to the architectures for FSDP2 wrapped models
162157 original_architectures = model .config .architectures
163158
164159 # FSDP2 requires an optimizer to be prepared together with the model
@@ -274,6 +269,8 @@ def calibrate(unwrapped_model):
274269 for k , v in batch .items ()
275270 }
276271 # Use outer model (FSDP-wrapped), not the parameter
272+ # Important: run the forward pass on the FSDP-wrapped outer model, not the unwrapped one
273+ # mtq.quantize will unwrap the model & pass it to the forward_loop as the argument
277274 model (** batch )
278275
279276 return calibrate
@@ -293,41 +290,27 @@ def export_model(
293290 export_path: Directory to export model to
294291 """
295292 export_dir = Path (export_path )
296- export_dir .mkdir (parents = True , exist_ok = True )
297293
298294 # Get quantization config
299- _ , hf_quant_config = _export_hf_checkpoint (model , dtype = torch .bfloat16 )
300-
301- # Gather and post-process state dict
302- model_state_dict = accelerator .get_state_dict (model )
303- post_state_dict = postprocess_state_dict (model_state_dict , 1.0 , None )
304-
305- # Save quantization config
306- if accelerator .is_main_process :
307- with open (export_dir / "hf_quant_config.json" , "w" ) as f :
308- json .dump (hf_quant_config , f , indent = 4 )
309-
310- # Convert config format
311- hf_quant_config = convert_hf_quant_config_format (hf_quant_config )
312-
313- # Save model
314- model .save_pretrained (
315- export_dir ,
316- state_dict = post_state_dict ,
317- save_modelopt_state = False ,
318- )
295+ export_hf_checkpoint (
296+ model ,
297+ dtype = torch .bfloat16 ,
298+ export_dir = export_dir ,
299+ save_modelopt_state = False ,
300+ is_fsdp2 = True ,
301+ accelerator = accelerator ,
302+ )
319303
320- # Update config with quantization info
321- config_path = export_dir / "config.json"
322- with open (config_path ) as f :
323- config_data = json .load (f )
304+ # Update config with quantization info
305+ config_path = export_dir / "config.json"
306+ with open (config_path ) as f :
307+ config_data = json .load (f )
324308
325- config_data ["quantization_config" ] = hf_quant_config
326- # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
327- config_data ["architectures" ] = architectures
309+ # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
310+ config_data ["architectures" ] = architectures
328311
329- with open (config_path , "w" ) as f :
330- json .dump (config_data , f , indent = 4 )
312+ with open (config_path , "w" ) as f :
313+ json .dump (config_data , f , indent = 4 )
331314
332315
333316def main (args ):
@@ -402,10 +385,13 @@ def main(args):
402385 print (f"Quantization completed in { elapsed :.2f} s" )
403386 mtq .print_quant_summary (model )
404387
388+ start_time = time .time ()
405389 export_model (model , accelerator , args .export_path , original_architectures )
390+ elapsed = time .time () - start_time
406391
407392 if accelerator .is_main_process :
408393 # Export the model
394+ print (f"Export completed in { elapsed :.2f} s" )
409395 print (f"Model exported to { args .export_path } " )
410396
411397 print ("Unpatching FSDP2 MP dtypes" )
0 commit comments