@@ -37,6 +37,7 @@ def get_lora_model(
37
37
"""
38
38
Loads a QLoRA model that has been trained using modelopt trainer.
39
39
"""
40
+ # TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
40
41
device_map = "auto"
41
42
if device == "cpu" :
42
43
device_map = "cpu"
@@ -72,17 +73,17 @@ def main(args):
72
73
try :
73
74
post_state_dict , hf_quant_config = _export_hf_checkpoint (model , is_modelopt_qlora = True )
74
75
75
- with open (f"{ export_dir } /base_model /hf_quant_config.json" , "w" ) as file :
76
+ with open (f"{ base_model_dir } /hf_quant_config.json" , "w" ) as file :
76
77
json .dump (hf_quant_config , file , indent = 4 )
77
78
78
79
hf_quant_config = convert_hf_quant_config_format (hf_quant_config )
79
80
80
81
# Save base model
81
- model .base_model .save_pretrained (f"{ export_dir } /base_model " , state_dict = post_state_dict )
82
+ model .base_model .save_pretrained (f"{ base_model_dir } " , state_dict = post_state_dict )
82
83
# Save adapters
83
84
model .save_pretrained (export_dir )
84
85
85
- config_path = f"{ export_dir } /base_model /config.json"
86
+ config_path = f"{ base_model_dir } /config.json"
86
87
87
88
config_data = model .config .to_dict ()
88
89
@@ -112,7 +113,11 @@ def main(args):
112
113
113
114
parser .add_argument ("--device" , default = "cuda" )
114
115
115
- parser .add_argument ("--export_path" , default = "exported_model" )
116
+ parser .add_argument (
117
+ "--export_path" ,
118
+ default = "exported_model" ,
119
+ help = "Path to save the exported model" ,
120
+ )
116
121
117
122
args = parser .parse_args ()
118
123
0 commit comments