@@ -233,45 +233,51 @@ def save_checkpoint( # noqa: C901
                 json.dump(transformer_config_dict, f, indent=2)
 
         if self.should_save_hf_model or save_as_hf:
-            # wait for everyone to dump to local
-            state_dict = self.weight_saver(
-                self.model,
-                self.hf_config,
-                dtype=self.param_dtype,
-                is_value_model=self.is_value_model,
-                tie_word_embeddings=self.share_embeddings_and_output_weights,
-            )
+            try:
+                # wait for everyone to dump to local
+                state_dict = self.weight_saver(
+                    self.model,
+                    self.hf_config,
+                    dtype=self.param_dtype,
+                    is_value_model=self.is_value_model,
+                    tie_word_embeddings=self.share_embeddings_and_output_weights,
+                )
 
-            torch.distributed.barrier()
-            if self.rank == 0:
-                # TODO: async save or use mbridge to save hf model
-                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
-                import warnings
+                torch.distributed.barrier()
+                if self.rank == 0:
+                    # TODO: async save or use mbridge to save hf model
+                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
+                    import warnings
 
-                from accelerate import init_empty_weights
+                    from accelerate import init_empty_weights
 
-                with init_empty_weights(), warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    if "mistral7b-rm" in self.config.model.path:
-                        from transformers import MistralForSequenceClassification
+                    with init_empty_weights(), warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        if "mistral7b-rm" in self.config.model.path:
+                            from transformers import MistralForSequenceClassification
 
-                        model = MistralForSequenceClassification.from_pretrained(
-                            self.config.model.path
-                        )  # use score head instead of lm_head
-                        state_dict["score.weight"] = state_dict["score.weight"]
-                    else:
-                        from transformers import AutoModelForCausalLM
+                            model = MistralForSequenceClassification.from_pretrained(
+                                self.config.model.path
+                            )  # use score head instead of lm_head
+                            state_dict["score.weight"] = state_dict["score.weight"]
+                        else:
+                            from transformers import AutoModelForCausalLM
 
-                        model = AutoModelForCausalLM.from_pretrained(
-                            self.config.model.path, torch_dtype="auto"
-                        )
-                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
-                log_with_rank(
-                    f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
-                    rank=self.rank,
-                    logger=logger,
-                    log_only_rank_0=True,
+                            model = AutoModelForCausalLM.from_pretrained(
+                                self.config.model.path, torch_dtype="auto"
+                            )
+                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+            except Exception as e:
+                logger.error(
+                    f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it."
                 )
+                logger.error(e)
 
     ray.get(
         self.checkpoint_monitor.register_thread_count.remote(
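For context, the pattern this commit introduces is: all ranks barrier after producing the merged state dict, rank 0 rebuilds the HF model skeleton on the meta device and writes the merged weights with `save_pretrained`, and any failure is logged rather than raised, since the Megatron checkpoint is already safely on disk by this point. A minimal standalone sketch of that pattern, assuming a `state_dict` already gathered by something like the trainer's `weight_saver` (the function name and parameters here are hypothetical, not part of the commit):

```python
import logging
import warnings

import torch.distributed as dist
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)


def best_effort_hf_export(model_path: str, ckpt_dir: str, state_dict: dict) -> None:
    """Best-effort HuggingFace export: failures are logged, never raised."""
    try:
        # Every rank must finish contributing to the merged state dict
        # before rank 0 writes the checkpoint.
        if dist.is_initialized():
            dist.barrier()
        if not dist.is_initialized() or dist.get_rank() == 0:
            # Build the architecture on the meta device so no memory is
            # spent on weights we are about to overwrite anyway.
            with init_empty_weights(), warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, torch_dtype="auto"
                )
            # save_pretrained serializes the explicitly passed state_dict,
            # not the (empty) parameters of the meta model.
            model.save_pretrained(ckpt_dir, state_dict=state_dict)
            logger.info("Saved HuggingFace checkpoint to %s", ckpt_dir)
    except Exception as e:
        # The main (Megatron) checkpoint is already written at this point,
        # so a failed HF export should warn, not crash the training job.
        logger.error("Failed to save HuggingFace model to %s: %s", ckpt_dir, e)
```

The trade-off of this design is that a broken HF export is only surfaced in the logs; the commit's error message points users at `use_mbridge=true` as the fallback save path.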