 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import modelopt.torch.opt as mto
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
 from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 
 RAND_SEED = 1234
 
+# Enable automatic save/load of the modelopt state via HuggingFace checkpointing
+mto.enable_huggingface_checkpointing()
+
 
 def get_lora_model(
     ckpt_path: str,
@@ -42,19 +46,20 @@ def get_lora_model(
     if device == "cpu":
         device_map = "cpu"
 
-    # Load model with adapters
+    # Load model
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
 
-    # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calib.pth", weights_only=False)
-    restore_from_modelopt_state(model, modelopt_state)
-    print_rank_0("Restored modelopt state")
+    # Restore modelopt state for LoRA models. For QAT/QAD models, the from_pretrained call handles this.
+    if hasattr(model, "peft_config"):
+        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        restore_from_modelopt_state(model, modelopt_state)
+        print_rank_0("Restored modelopt state")
 
-    # Restore modelopt quantizer state dict
-    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
-    if modelopt_weights is not None:
-        set_quantizer_state_dict(model, modelopt_weights)
-        print_rank_0("Restored modelopt quantizer state dict")
+        # Restore modelopt quantizer state dict
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+        if modelopt_weights is not None:
+            set_quantizer_state_dict(model, modelopt_weights)
+            print_rank_0("Restored modelopt quantizer state dict")
 
     return model
 
@@ -63,25 +68,31 @@ def main(args):
     # Load model
     model = get_lora_model(args.pyt_ckpt_path, args.device)
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
+    is_qlora = hasattr(model, "peft_config")
 
     # Export HF checkpoint
     export_dir = Path(args.export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
-    base_model_dir = export_dir / "base_model"
-    base_model_dir.mkdir(parents=True, exist_ok=True)
+    if is_qlora:
+        base_model_dir = export_dir / "base_model"
+        base_model_dir.mkdir(parents=True, exist_ok=True)
+    else:
+        base_model_dir = export_dir
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)
 
         with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
-        # Save base model
-        model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
-        # Save adapters
-        model.save_pretrained(export_dir)
+        # Save model
+        if is_qlora:
+            model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
+            model.save_pretrained(export_dir)
+        else:
+            model.save_pretrained(export_dir, state_dict=post_state_dict)
 
         config_path = f"{base_model_dir}/config.json"
 
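For reference, a minimal sketch (not part of the diff) of the behavior the new mto.enable_huggingface_checkpointing() call relies on for QAT/QAD checkpoints, which is why the manual restore_from_modelopt_state path above is gated on hasattr(model, "peft_config"). The checkpoint path "qat_ckpt_dir" is hypothetical.

import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Patch HuggingFace save_pretrained/from_pretrained so the modelopt state is
# stored with the checkpoint and restored automatically on load.
mto.enable_huggingface_checkpointing()

# Hypothetical QAT/QAD checkpoint directory; no explicit
# restore_from_modelopt_state() call is needed here.
model = AutoModelForCausalLM.from_pretrained("qat_ckpt_dir")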