@@ -21,6 +21,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

+import modelopt.torch.opt as mto
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
 from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
@@ -29,6 +30,9 @@

 RAND_SEED = 1234

+# Enable automatic save/load of the modelopt state with HuggingFace checkpointing
+mto.enable_huggingface_checkpointing()
+

 def get_lora_model(
     ckpt_path: str,
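Note on the added `mto.enable_huggingface_checkpointing()` call: per the comment above, it hooks the modelopt state into HuggingFace's `save_pretrained`/`from_pretrained` flow, which is why the QAT/QAD path below needs no explicit restore. A minimal usage sketch (the checkpoint path is a placeholder):

```python
import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Must be called before from_pretrained so a checkpoint saved with the
# modelopt state attached is restored transparently.
mto.enable_huggingface_checkpointing()

model = AutoModelForCausalLM.from_pretrained("/path/to/qat_ckpt")  # placeholder path
```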
@@ -42,19 +46,20 @@ def get_lora_model(
     if device == "cpu":
         device_map = "cpu"

-    # Load model with adapters
+    # Load model
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

-    # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calib.pth", weights_only=False)
-    restore_from_modelopt_state(model, modelopt_state)
-    print_rank_0("Restored modelopt state")
+    # Restore the modelopt state for LoRA models; for QAT/QAD models, from_pretrained handles this
+    if hasattr(model, "peft_config"):
+        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        restore_from_modelopt_state(model, modelopt_state)
+        print_rank_0("Restored modelopt state")

-    # Restore modelopt quantizer state dict
-    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
-    if modelopt_weights is not None:
-        set_quantizer_state_dict(model, modelopt_weights)
-        print_rank_0("Restored modelopt quantizer state dict")
+        # Restore the modelopt quantizer state dict
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+        if modelopt_weights is not None:
+            set_quantizer_state_dict(model, modelopt_weights)
+            print_rank_0("Restored modelopt quantizer state dict")

     return model

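The LoRA branch above expects a `modelopt_state_train.pth` file next to the checkpoint. A hedged sketch of the save-side counterpart that would produce it, assuming a `get_quantizer_state_dict` helper mirroring the `set_quantizer_state_dict` used here:

```python
import torch
import modelopt.torch.opt as mto
from modelopt.torch.quantization.utils import get_quantizer_state_dict  # assumed helper

# Dump the modelopt state plus quantizer weights next to the LoRA
# checkpoint after training, in the layout the loader above expects.
modelopt_state = mto.modelopt_state(model)
modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(model)
torch.save(modelopt_state, f"{ckpt_path}/modelopt_state_train.pth")
```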
@@ -63,25 +68,31 @@ def main(args):
     # Load model
     model = get_lora_model(args.pyt_ckpt_path, args.device)
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
+    is_qlora = hasattr(model, "peft_config")

     # Export HF checkpoint
     export_dir = Path(args.export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
-    base_model_dir = export_dir / "base_model"
-    base_model_dir.mkdir(parents=True, exist_ok=True)
+    if is_qlora:
+        base_model_dir = export_dir / "base_model"
+        base_model_dir.mkdir(parents=True, exist_ok=True)
+    else:
+        base_model_dir = export_dir

     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)

         with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)

         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

-        # Save base model
-        model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
-        # Save adapters
-        model.save_pretrained(export_dir)
+        # Save the model: for QLoRA, export the base model and the adapters separately
+        if is_qlora:
+            model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
+            model.save_pretrained(export_dir)
+        else:
+            model.save_pretrained(export_dir, state_dict=post_state_dict)

         config_path = f"{base_model_dir}/config.json"
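With this branch, the export layout differs by mode: QLoRA writes the adapters under `export_path` and the quantized base model under `export_path/base_model`, while QAT/QAD writes everything directly into `export_path`. A small hedged check (the export path is a placeholder) that the quant config landed in the expected directory:

```python
import json
from pathlib import Path

export_dir = Path("/export/model")  # placeholder
# QLoRA exports nest the base model; QAT/QAD exports are flat.
base_dir = export_dir / "base_model" if (export_dir / "base_model").is_dir() else export_dir

with open(base_dir / "hf_quant_config.json") as f:
    print(json.load(f))  # quantization config written by _export_hf_checkpoint
```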