Skip to content

Commit bb2d6ef

Browse files
committed
added TODO
Signed-off-by: Suguna Velury <[email protected]>
1 parent bc6c835 commit bb2d6ef

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

examples/llm_qat/export.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def get_lora_model(
3737
"""
3838
Loads a QLoRA model that has been trained using the modelopt trainer.
3939
"""
40+
# TODO: Add support for merging adapters in BF16, and for merging adapters with quantization, for deployment
4041
device_map = "auto"
4142
if device == "cpu":
4243
device_map = "cpu"
@@ -72,17 +73,17 @@ def main(args):
7273
try:
7374
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
7475

75-
with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file:
76+
with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
7677
json.dump(hf_quant_config, file, indent=4)
7778

7879
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
7980

8081
# Save base model
81-
model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict)
82+
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
8283
# Save adapters
8384
model.save_pretrained(export_dir)
8485

85-
config_path = f"{export_dir}/base_model/config.json"
86+
config_path = f"{base_model_dir}/config.json"
8687

8788
config_data = model.config.to_dict()
8889

@@ -112,7 +113,11 @@ def main(args):
112113

113114
parser.add_argument("--device", default="cuda")
114115

115-
parser.add_argument("--export_path", default="exported_model")
116+
parser.add_argument(
117+
"--export_path",
118+
default="exported_model",
119+
help="Path to save the exported model",
120+
)
116121

117122
args = parser.parse_args()
118123

0 commit comments

Comments
 (0)