@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
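
For reference, `trust_remote_code` is the standard `transformers` flag that lets the Auto* loaders execute custom modeling/tokenization code bundled in a model repository; without it, checkpoints that ship their own classes fail to load. A minimal sketch of the pattern the commit threads through `get_model` (the model id and the use of `AutoModelForCausalLM` are placeholders for illustration, not taken from this file):

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "some-org/custom-decoder-embedder"  # placeholder model id
trust_remote_code = True  # mirrors model_args.trust_remote_code in the diff

# Propagate the flag to every loader that may pull in repository-defined code.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    trust_remote_code=trust_remote_code,
)
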
@@ -129,13 +132,15 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -152,6 +157,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         model = model_args.from_config(config)
@@ -171,7 +177,9 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
     model = PeftModel.from_pretrained(model, find_largest_checkpoint(output_dir))
     model = model.merge_and_unload()

-    model.save_pretrained(os.path.join(output_dir, 'merged_model'))
-
-    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
     tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))
+
+    # modify the vocab size in the model configuration
+    model.config.vocab_size = len(tokenizer)
+    model.save_pretrained(os.path.join(output_dir, 'merged_model'))
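
The last hunk also reorders the merged-model save: the tokenizer is loaded with `trust_remote_code` and written first, `model.config.vocab_size` is synced to `len(tokenizer)` (the tokenizer may have grown through added special tokens while the base config still reports the original vocabulary size), and only then is the merged model saved. A quick sanity check of the saved artifact, assuming the embeddings were resized to the tokenizer during training; the directory path and `AutoModelForCausalLM` are placeholders, not taken from the script:

from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "output_dir/merged_model"  # stands in for os.path.join(output_dir, 'merged_model')
tokenizer = AutoTokenizer.from_pretrained(merged_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(merged_dir, trust_remote_code=True)

# With the vocab-size sync in the commit, the saved config should agree with the tokenizer.
assert model.config.vocab_size == len(tokenizer)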