@@ -37,6 +37,7 @@ def get_lora_model(
3737 """
3838 Loads a QLoRA model that has been trained using modelopt trainer.
3939 """
40+ # TODO: Add support for merging adapters in BF16, and for merging adapters with quantization, for deployment.
4041 device_map = "auto"
4142 if device == "cpu" :
4243 device_map = "cpu"
@@ -72,17 +73,17 @@ def main(args):
7273 try :
7374 post_state_dict , hf_quant_config = _export_hf_checkpoint (model , is_modelopt_qlora = True )
7475
75- with open (f"{ export_dir } /base_model /hf_quant_config.json" , "w" ) as file :
76+ with open (f"{ base_model_dir } /hf_quant_config.json" , "w" ) as file :
7677 json .dump (hf_quant_config , file , indent = 4 )
7778
7879 hf_quant_config = convert_hf_quant_config_format (hf_quant_config )
7980
8081 # Save base model
81- model .base_model .save_pretrained (f"{ export_dir } /base_model " , state_dict = post_state_dict )
82+ model .base_model .save_pretrained (f"{ base_model_dir } " , state_dict = post_state_dict )
8283 # Save adapters
8384 model .save_pretrained (export_dir )
8485
85- config_path = f"{ export_dir } /base_model /config.json"
86+ config_path = f"{ base_model_dir } /config.json"
8687
8788 config_data = model .config .to_dict ()
8889
@@ -112,7 +113,11 @@ def main(args):
112113
113114 parser .add_argument ("--device" , default = "cuda" )
114115
115- parser .add_argument ("--export_path" , default = "exported_model" )
116+ parser .add_argument (
117+ "--export_path" ,
118+ default = "exported_model" ,
119+ help = "Path to save the exported model" ,
120+ )
116121
117122 args = parser .parse_args ()
118123
0 commit comments