From 0ab395f4bfc93a7e11f619cf54cbf6c6c0a9e58e Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Thu, 27 Nov 2025 15:22:47 +0800
Subject: [PATCH] Save bf16 model

---
 trinity/trainer/verl/fsdp_checkpoint_manager.py     | 2 ++
 trinity/trainer/verl/megatron_checkpoint_manager.py | 5 +++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py
index 576e1e45fa..c168270366 100644
--- a/trinity/trainer/verl/fsdp_checkpoint_manager.py
+++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -360,6 +360,8 @@ def save_checkpoint(  # noqa: C901
         if self._save_model_thread is not None:
             self._save_model_thread.join()
 
+        state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
+
         def _save_model():
             runtime_context = ray.get_runtime_context()
             node_id = runtime_context.get_node_id()
diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py
index f24c8aa4ef..8674934452 100644
--- a/trinity/trainer/verl/megatron_checkpoint_manager.py
+++ b/trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -257,15 +257,16 @@ def save_checkpoint(  # noqa: C901
                     from transformers import MistralForSequenceClassification
 
                     model = MistralForSequenceClassification.from_pretrained(
-                        self.config.model.path
+                        self.config.model.path, torch_dtype=torch.bfloat16
                     )  # use score head instead of lm_head
                     state_dict["score.weight"] = state_dict["score.weight"]
                 else:
                     from transformers import AutoModelForCausalLM
 
                     model = AutoModelForCausalLM.from_pretrained(
-                        self.config.model.path, torch_dtype="auto"
+                        self.config.model.path, torch_dtype=torch.bfloat16
                     )
+                state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
                 model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
                 log_with_rank(
                     f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
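
Note (outside the patch, for illustration only): both hunks apply the same
technique, downcasting every tensor in the checkpoint state dict to bfloat16
before passing it to save_pretrained, and loading the reference HF model in
bfloat16 so its dtype matches the saved weights. A minimal standalone sketch
of that technique follows; the model name "gpt2" and the output path are
hypothetical placeholders, not values from this repository.

    import torch
    from transformers import AutoModelForCausalLM

    # Load the reference model directly in bf16, as the patch now does in
    # both checkpoint managers (hypothetical model name).
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

    # Downcast the (possibly fp32) weights to bf16, mirroring the dict
    # comprehension the patch adds; this halves the size of an fp32 checkpoint.
    state_dict = {k: v.to(torch.bfloat16) for k, v in model.state_dict().items()}

    # save_pretrained writes the supplied state_dict, so the shards on disk
    # end up stored in bf16 (hypothetical output path).
    model.save_pretrained("/tmp/bf16_ckpt", state_dict=state_dict)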