2121import modelopt .torch .opt as mto
2222import modelopt .torch .quantization as mtq
2323from modelopt .torch .export import get_model_type
24- from modelopt .torch .export .unified_export_hf import export_hf_checkpoint
24+ from modelopt .torch .export .convert_hf_config import convert_hf_quant_config_format
25+ from modelopt .torch .export .unified_export_hf import _export_hf_checkpoint
2526from modelopt .torch .quantization .config import need_calibration
2627from modelopt .torch .quantization .utils import patch_fsdp_mp_dtypes
2728from modelopt .torch .utils .dataset_utils import get_dataset_dataloader , get_supported_datasets
# Seed used for dataset shuffling / calibration sampling so runs are reproducible.
RAND_SEED = 1234

# CLI quantization formats mapped to their ModelOpt quantization config dicts.
# NOTE(review): this commit drops the int8/int8_sq/w4a8_*/fp8_pb_wo/fp8_pc_pt
# entries — presumably formats not supported by the unified HF export path;
# confirm against the ModelOpt release notes before re-adding any of them.
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
    "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
    "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
}
4741
5246 "nvfp4_affine" : "NVFP4_AFFINE_KV_CFG" ,
5347}
5448
55- SUPPORTED_QFORMATS = [
56- "int8_wo" ,
57- "int4_awq" ,
58- "fp8" ,
59- "nvfp4" ,
60- "nvfp4_awq" ,
61- "w4a8_awq" ,
62- "fp8_pb_wo" ,
63- "w4a8_mxfp4_fp8" ,
64- "nvfp4_mlp_only" ,
65- ]
66-
6749
# Enable HuggingFace checkpointing so ModelOpt state round-trips through
# save_pretrained / from_pretrained on quantized models.
mto.enable_huggingface_checkpointing()
@@ -83,7 +65,7 @@ def parse_args():
8365 parser .add_argument (
8466 "--qformat" ,
8567 default = "fp8" ,
86- choices = SUPPORTED_QFORMATS ,
68+ choices = QUANT_CFG_CHOICES . keys () ,
8769 help = "Quantization format" ,
8870 )
8971 parser .add_argument (
@@ -290,27 +272,32 @@ def export_model(
290272 export_path: Directory to export model to
291273 """
292274 export_dir = Path (export_path )
275+ export_dir .mkdir (parents = True , exist_ok = True )
293276
294- # Get quantization config
295- export_hf_checkpoint (
296- model ,
297- dtype = torch .bfloat16 ,
298- export_dir = export_dir ,
299- save_modelopt_state = False ,
300- is_fsdp2 = True ,
301- accelerator = accelerator ,
302- )
277+ post_state_dict , hf_quant_config = _export_hf_checkpoint (model , torch .bfloat16 )
278+
279+ if accelerator .is_main_process :
280+ # Save hf_quant_config.json for backward compatibility
281+ with open (f"{ export_dir } /hf_quant_config.json" , "w" ) as file :
282+ json .dump (hf_quant_config , file , indent = 4 )
283+
284+ hf_quant_config = convert_hf_quant_config_format (hf_quant_config )
285+
286+ # Save model
287+ model .save_pretrained (export_dir , state_dict = post_state_dict , save_modelopt_state = False )
288+
289+ original_config = f"{ export_dir } /config.json"
290+ config_data = {}
303291
304- # Update config with quantization info
305- config_path = export_dir / "config.json"
306- with open (config_path ) as f :
307- config_data = json .load (f )
292+ with open (original_config ) as file :
293+ config_data = json .load (file )
308294
309- # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
310- config_data ["architectures" ] = architectures
295+ config_data ["quantization_config" ] = hf_quant_config
296+ # Update config architectures to use original architectures that does not have FSDP prefix
297+ config_data ["architectures" ] = architectures
311298
312- with open (config_path , "w" ) as f :
313- json .dump (config_data , f , indent = 4 )
299+ with open (original_config , "w" ) as file :
300+ json .dump (config_data , file , indent = 4 )
314301
315302
316303def main (args ):
@@ -320,9 +307,9 @@ def main(args):
320307 raise OSError ("GPU is required for quantization." )
321308
322309 # Validate quantization format
323- if args .qformat not in SUPPORTED_QFORMATS :
310+ if args .qformat not in QUANT_CFG_CHOICES :
324311 raise ValueError (
325- f"Quantization format { args .qformat } not supported. Choose from: { SUPPORTED_QFORMATS } "
312+ f"Quantization format { args .qformat } not supported. Choose from: { QUANT_CFG_CHOICES . keys () } "
326313 )
327314
328315 # Set random seeds
0 commit comments