diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py
index 576e1e45fa..c168270366 100644
--- a/trinity/trainer/verl/fsdp_checkpoint_manager.py
+++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -360,6 +360,8 @@ def save_checkpoint(  # noqa: C901
         if self._save_model_thread is not None:
             self._save_model_thread.join()
 
+        state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
+
        def _save_model():
            runtime_context = ray.get_runtime_context()
            node_id = runtime_context.get_node_id()
diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py
index f24c8aa4ef..8674934452 100644
--- a/trinity/trainer/verl/megatron_checkpoint_manager.py
+++ b/trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -257,15 +257,16 @@ def save_checkpoint(  # noqa: C901
                     from transformers import MistralForSequenceClassification
 
                     model = MistralForSequenceClassification.from_pretrained(
-                        self.config.model.path
+                        self.config.model.path, torch_dtype=torch.bfloat16
                     )  # use score head instead of lm_head
                     state_dict["score.weight"] = state_dict["score.weight"]
                 else:
                     from transformers import AutoModelForCausalLM
 
                     model = AutoModelForCausalLM.from_pretrained(
-                        self.config.model.path, torch_dtype="auto"
+                        self.config.model.path, torch_dtype=torch.bfloat16
                     )
+                state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
                 model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
                 log_with_rank(
                     f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",