File tree Expand file tree Collapse file tree 2 files changed +5
-2
lines changed
Expand file tree Collapse file tree 2 files changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -360,6 +360,8 @@ def save_checkpoint( # noqa: C901
360360 if self ._save_model_thread is not None :
361361 self ._save_model_thread .join ()
362362
363+ state_dict = {k : v .to (torch .bfloat16 ) for k , v in state_dict .items ()}
364+
363365 def _save_model ():
364366 runtime_context = ray .get_runtime_context ()
365367 node_id = runtime_context .get_node_id ()
Original file line number Diff line number Diff line change @@ -257,15 +257,16 @@ def save_checkpoint( # noqa: C901
257257 from transformers import MistralForSequenceClassification
258258
259259 model = MistralForSequenceClassification .from_pretrained (
260- self .config .model .path
260+ self .config .model .path , torch_dtype = torch . bfloat16
261261 ) # use score head instead of lm_head
262262 state_dict ["score.weight" ] = state_dict ["score.weight" ]
263263 else :
264264 from transformers import AutoModelForCausalLM
265265
266266 model = AutoModelForCausalLM .from_pretrained (
267- self .config .model .path , torch_dtype = "auto"
267+ self .config .model .path , torch_dtype = torch . bfloat16
268268 )
269+ state_dict = {k : v .to (torch .bfloat16 ) for k , v in state_dict .items ()}
269270 model .save_pretrained (hf_model_ckpt_path , state_dict = state_dict )
270271 log_with_rank (
271272 f"Saved Huggingface config and tokenizer to { hf_model_ckpt_path } " ,
You can’t perform that action at this time.
0 commit comments