Skip to content

Commit cebd16f

Browse files
committed
fix unittest
1 parent: 05672c7 · commit: cebd16f

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

trinity/common/config_validator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,14 @@ def validate(self, config: Config) -> None:
777777
config.buffer.train_batch_size = (
778778
config.buffer.batch_size * config.algorithm.repeat_times
779779
)
780+
if (
781+
config.mode in {"train", "both"}
782+
and config.buffer.train_batch_size % config.cluster.trainer_gpu_num != 0
783+
):
784+
raise ValueError(
785+
f"batch_size ({config.buffer.train_batch_size}) must be "
786+
f"divisible by ({config.cluster.trainer_gpu_num})."
787+
)
780788

781789
# create buffer.cache_dir at <checkpoint_root_dir>/<project>/<name>/buffer
782790
config.buffer.cache_dir = os.path.abspath(os.path.join(config.checkpoint_job_dir, "buffer"))

trinity/common/verl_config.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -421,13 +421,9 @@ def _adjust_token_len_if_needed(
421421

422422
def synchronize_config(self, config: Config) -> None: # noqa: C901
423423
"""Synchronize config."""
424+
# Trainer Config
424425
self.trainer.nnodes = config.cluster.trainer_node_num
425426
self.trainer.n_gpus_per_node = config.cluster.trainer_gpu_num_per_node
426-
world_size = config.cluster.trainer_gpu_num
427-
if config.buffer.train_batch_size % world_size != 0:
428-
raise ValueError(
429-
f"batch_size ({config.buffer.train_batch_size}) must be divisible by ({world_size})"
430-
)
431427
self.trainer.total_training_steps = config.trainer.total_steps or sys.maxsize
432428
self.trainer.sync_freq = config.synchronizer.sync_interval
433429
self.trainer.save_freq = config.trainer.save_interval
@@ -443,9 +439,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
443439
else:
444440
self.trainer.resume_mode = "auto"
445441

446-
self.data.train_batch_size = (
447-
config.buffer.train_batch_size
448-
) # kept to pass RayPPOTrainer._validate_config
442+
# kept to pass RayPPOTrainer._validate_config
443+
self.data.train_batch_size = config.buffer.train_batch_size
449444

450445
self.synchronizer = config.synchronizer
451446
self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout

0 commit comments

Comments (0)