
Commit 114ad19

Fix bugs in benchmark checkpoint loading and Megatron HF model saving.

1 parent ba33438

File tree: 2 files changed, +43 additions, -34 deletions


trinity/manager/synchronizer.py

Lines changed: 4 additions & 1 deletion

@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self._modules = {module_ref}
         self._modules_lock = asyncio.Lock()
         asyncio.create_task(self._check_modules())
-        if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+        if (
+            self.config.mode != "bench"
+            and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
+        ):
             asyncio.create_task(self._find_latest_state_dict())

     async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
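
The change gates the checkpoint watcher on the run mode: benchmark runs load their checkpoints explicitly rather than tracking the trainer's latest state dict, so starting _find_latest_state_dict there presumably conflicted with benchmark checkpoint loading. A minimal sketch of the guard pattern, with hypothetical names (find_latest_state_dict stands in for Synchronizer._find_latest_state_dict; the string values mirror config.mode and SyncMethod.CHECKPOINT):

import asyncio
from typing import Optional

async def find_latest_state_dict() -> None:
    # Hypothetical stand-in for Synchronizer._find_latest_state_dict:
    # periodically look for the trainer's newest state dict on disk.
    while True:
        await asyncio.sleep(30)

def maybe_start_watcher(mode: str, sync_method: str) -> Optional["asyncio.Task"]:
    # Benchmark runs evaluate fixed checkpoints, so the watcher is only
    # scheduled for non-bench runs that synchronize via checkpoints.
    if mode != "bench" and sync_method == "checkpoint":
        return asyncio.create_task(find_latest_state_dict())
    return None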

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 39 additions & 33 deletions

@@ -233,45 +233,51 @@ def save_checkpoint(  # noqa: C901
                 json.dump(transformer_config_dict, f, indent=2)

         if self.should_save_hf_model or save_as_hf:
-            # wait for everyone to dump to local
-            state_dict = self.weight_saver(
-                self.model,
-                self.hf_config,
-                dtype=self.param_dtype,
-                is_value_model=self.is_value_model,
-                tie_word_embeddings=self.share_embeddings_and_output_weights,
-            )
+            try:
+                # wait for everyone to dump to local
+                state_dict = self.weight_saver(
+                    self.model,
+                    self.hf_config,
+                    dtype=self.param_dtype,
+                    is_value_model=self.is_value_model,
+                    tie_word_embeddings=self.share_embeddings_and_output_weights,
+                )

-            torch.distributed.barrier()
-            if self.rank == 0:
-                # TODO: async save or use mbridge to save hf model
-                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
-                import warnings
+                torch.distributed.barrier()
+                if self.rank == 0:
+                    # TODO: async save or use mbridge to save hf model
+                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
+                    import warnings

-                from accelerate import init_empty_weights
+                    from accelerate import init_empty_weights

-                with init_empty_weights(), warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    if "mistral7b-rm" in self.config.model.path:
-                        from transformers import MistralForSequenceClassification
+                    with init_empty_weights(), warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        if "mistral7b-rm" in self.config.model.path:
+                            from transformers import MistralForSequenceClassification

-                        model = MistralForSequenceClassification.from_pretrained(
-                            self.config.model.path
-                        )  # use score head instead of lm_head
-                        state_dict["score.weight"] = state_dict["score.weight"]
-                    else:
-                        from transformers import AutoModelForCausalLM
+                            model = MistralForSequenceClassification.from_pretrained(
+                                self.config.model.path
+                            )  # use score head instead of lm_head
+                            state_dict["score.weight"] = state_dict["score.weight"]
+                        else:
+                            from transformers import AutoModelForCausalLM

-                        model = AutoModelForCausalLM.from_pretrained(
-                            self.config.model.path, torch_dtype="auto"
-                        )
-                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
-                log_with_rank(
-                    f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
-                    rank=self.rank,
-                    logger=logger,
-                    log_only_rank_0=True,
+                            model = AutoModelForCausalLM.from_pretrained(
+                                self.config.model.path, torch_dtype="auto"
+                            )
+                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+            except Exception as e:
+                logger.error(
+                    f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it."
                 )
+                logger.error(e)

         ray.get(
             self.checkpoint_monitor.register_thread_count.remote(
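
This makes the HF export fail-soft: the Megatron-format checkpoint has already been written by this point, so any exception while gathering the full weights or calling save_pretrained is logged (with a hint to try `use_mbridge=true`) instead of aborting the run. A minimal sketch of the pattern, assuming a hypothetical export_hf_model callable in place of the weight_saver/save_pretrained sequence:

import logging

logger = logging.getLogger(__name__)

def save_hf_best_effort(export_hf_model, local_path: str) -> None:
    # export_hf_model is a hypothetical stand-in for gathering the full
    # state dict and calling model.save_pretrained on rank 0.
    try:
        export_hf_model(local_path)
    except Exception as e:
        # Fail soft: the Megatron checkpoint is already on disk, so log
        # the failure and the suggested workaround instead of raising.
        logger.error(
            f"Failed to save Huggingface model to {local_path}, "
            "you can try to set `use_mbridge=true` to save it."
        )
        logger.error(e)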
