Skip to content

Commit 50866a2

Browse files
committed
apply reviews from gemini
1 parent 0992dd4 commit 50866a2

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,18 +466,22 @@ def save_checkpoint(
466466
self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
467467
):
468468
if self.should_save_model:
469-
state_dict_thread_count += self._save_model(local_path, global_step)
469+
if self._save_model(local_path, global_step):
470+
state_dict_thread_count += 1
470471

471472
if self.should_save_optimizer:
472-
checkpoint_thread_count += self._save_optimizer(local_path, global_step)
473+
if self._save_optimizer(local_path, global_step):
474+
checkpoint_thread_count += 1
473475

474476
if self.should_save_extra:
475-
checkpoint_thread_count += self._save_extra_state(local_path, global_step)
477+
if self._save_extra_state(local_path, global_step):
478+
checkpoint_thread_count += 1
476479

477480
self._save_tokenizer(local_path, global_step)
478481

479482
if self.should_save_hf_model or save_as_hf:
480-
checkpoint_thread_count += self._save_hf_model(local_path, global_step)
483+
if self._save_hf_model(local_path, global_step):
484+
checkpoint_thread_count += 1
481485

482486
ray.get(
483487
self.checkpoint_monitor.register_thread_count.remote(

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def save_checkpoint(
347347

348348
state_dict_thread_count = 0
349349
if self.should_save_model:
350-
state_dict_thread_count += self._save_state_dict(local_path, global_step)
350+
if self._save_state_dict(local_path, global_step):
351+
state_dict_thread_count += 1
351352

352353
self._save_tokenizer(local_path, global_step)
353354

0 commit comments

Comments
 (0)