Commit 8ca6e57

Save bf16 model (#414)
1 parent 0b17681 commit 8ca6e57

File tree

2 files changed: +5 -2 lines changed

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 2 additions & 0 deletions
@@ -360,6 +360,8 @@ def save_checkpoint(  # noqa: C901
         if self._save_model_thread is not None:
             self._save_model_thread.join()
 
+        state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
+
         def _save_model():
             runtime_context = ray.get_runtime_context()
             node_id = runtime_context.get_node_id()
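
The hunk above casts every tensor in the gathered state dict to bfloat16 before the background save thread writes it, roughly halving the exported checkpoint size when the gathered tensors are fp32. A minimal sketch of that cast, using a toy nn.Module and a made-up output path in place of the repo's FSDP plumbing:

import torch
import torch.nn as nn

# Stand-in for the trained model; the real code gathers a full FSDP state dict.
model = nn.Linear(16, 16)
state_dict = model.state_dict()  # tensors are fp32 by default

# Cast every tensor to bf16 so the saved checkpoint takes about half the bytes.
state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}

torch.save(state_dict, "model_bf16.pt")  # hypothetical output path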

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 3 additions & 2 deletions
@@ -257,15 +257,16 @@ def save_checkpoint(  # noqa: C901
                 from transformers import MistralForSequenceClassification
 
                 model = MistralForSequenceClassification.from_pretrained(
-                    self.config.model.path
+                    self.config.model.path, torch_dtype=torch.bfloat16
                 )  # use score head instead of lm_head
                 state_dict["score.weight"] = state_dict["score.weight"]
             else:
                 from transformers import AutoModelForCausalLM
 
                 model = AutoModelForCausalLM.from_pretrained(
-                    self.config.model.path, torch_dtype="auto"
+                    self.config.model.path, torch_dtype=torch.bfloat16
                 )
+            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
             model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
             log_with_rank(
                 f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
