trinity/manager/synchronizer.py: 4 additions & 1 deletion
@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self._modules = {module_ref}
         self._modules_lock = asyncio.Lock()
         asyncio.create_task(self._check_modules())
-        if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+        if (
+            self.config.mode != "bench"
+            and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
+        ):
             asyncio.create_task(self._find_latest_state_dict())

     async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
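Note: the guard above means a Synchronizer started in bench mode never spawns the checkpoint-polling task, even when the sync method is CHECKPOINT. A minimal sketch of that condition, using simplified stand-in Config classes (illustrative only, not the project's real config definitions):

from dataclasses import dataclass, field
from enum import Enum


class SyncMethod(Enum):
    CHECKPOINT = "checkpoint"


@dataclass
class SynchronizerConfig:
    sync_method: SyncMethod = SyncMethod.CHECKPOINT


@dataclass
class Config:
    mode: str = "train"
    synchronizer: SynchronizerConfig = field(default_factory=SynchronizerConfig)


def should_poll_checkpoints(config: Config) -> bool:
    # Same condition as the new guard in Synchronizer.__init__:
    # bench runs skip the _find_latest_state_dict polling task entirely.
    return (
        config.mode != "bench"
        and config.synchronizer.sync_method == SyncMethod.CHECKPOINT
    )


assert should_poll_checkpoints(Config(mode="train"))
assert not should_poll_checkpoints(Config(mode="bench"))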
trinity/trainer/verl/megatron_checkpoint_manager.py: 39 additions & 33 deletions
@@ -233,45 +233,51 @@ def save_checkpoint( # noqa: C901
                 json.dump(transformer_config_dict, f, indent=2)

         if self.should_save_hf_model or save_as_hf:
-            # wait for everyone to dump to local
-            state_dict = self.weight_saver(
-                self.model,
-                self.hf_config,
-                dtype=self.param_dtype,
-                is_value_model=self.is_value_model,
-                tie_word_embeddings=self.share_embeddings_and_output_weights,
-            )
+            try:
+                # wait for everyone to dump to local
+                state_dict = self.weight_saver(
+                    self.model,
+                    self.hf_config,
+                    dtype=self.param_dtype,
+                    is_value_model=self.is_value_model,
+                    tie_word_embeddings=self.share_embeddings_and_output_weights,
+                )

-            torch.distributed.barrier()
-            if self.rank == 0:
-                # TODO: async save or use mbridge to save hf model
-                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
-                import warnings
+                torch.distributed.barrier()
+                if self.rank == 0:
+                    # TODO: async save or use mbridge to save hf model
+                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
+                    import warnings

-                from accelerate import init_empty_weights
+                    from accelerate import init_empty_weights

-                with init_empty_weights(), warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    if "mistral7b-rm" in self.config.model.path:
-                        from transformers import MistralForSequenceClassification
+                    with init_empty_weights(), warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        if "mistral7b-rm" in self.config.model.path:
+                            from transformers import MistralForSequenceClassification

-                        model = MistralForSequenceClassification.from_pretrained(
-                            self.config.model.path
-                        )  # use score head instead of lm_head
-                        state_dict["score.weight"] = state_dict["score.weight"]
-                    else:
-                        from transformers import AutoModelForCausalLM
+                            model = MistralForSequenceClassification.from_pretrained(
+                                self.config.model.path
+                            )  # use score head instead of lm_head
+                            state_dict["score.weight"] = state_dict["score.weight"]
+                        else:
+                            from transformers import AutoModelForCausalLM

-                        model = AutoModelForCausalLM.from_pretrained(
-                            self.config.model.path, torch_dtype="auto"
-                        )
-                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
-                log_with_rank(
-                    f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
-                    rank=self.rank,
-                    logger=logger,
-                    log_only_rank_0=True,
+                            model = AutoModelForCausalLM.from_pretrained(
+                                self.config.model.path, torch_dtype="auto"
+                            )
+                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+            except Exception as e:
+                logger.error(
+                    f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it."
                 )
+                logger.error(e)

         ray.get(
             self.checkpoint_monitor.register_thread_count.remote(
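Note: this change turns the HuggingFace export into a best-effort step: any exception raised while gathering weights or saving is logged (with a hint to retry with `use_mbridge=true`) instead of aborting the surrounding Megatron checkpoint save. A rough sketch of that pattern, with hypothetical save_megatron/export_hf callables standing in for the real checkpoint and export logic:

import logging

logger = logging.getLogger(__name__)


def save_checkpoint_with_optional_hf_export(save_megatron, export_hf, local_path):
    """Sketch only: save the main checkpoint, then try the HF export.

    `save_megatron` and `export_hf` are hypothetical callables standing in for
    the manager's real Megatron checkpoint and HuggingFace-export logic.
    """
    save_megatron(local_path)
    try:
        # Best-effort: gather weights and write a HuggingFace-format copy.
        export_hf(local_path)
    except Exception as e:
        # Mirror the new behaviour: log and continue instead of raising,
        # so the main checkpoint is still registered with the monitor.
        logger.error(
            f"Failed to save Huggingface model to {local_path}, "
            "you can try to set `use_mbridge=true` to save it."
        )
        logger.error(e)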