Skip to content

Commit 4a71cc7

Browse files
committed
Fix GPU calculation for train mode
1 parent 92474a1 commit 4a71cc7

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ buffer:
5959
name: gsm8k_buffer
6060
storage_type: queue
6161
path: 'sqlite:///gsm8k.db'
62+
# sft_warmup_steps: 0
6263
# sft_warmup_dataset: # Uncomment these to enable sft warmup
6364
# name: warmup_data
6465
# storage_type: file

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ class TrainerConfig:
266266
# trainer configs
267267
actor_use_kl_loss: bool = False
268268
actor_kl_loss_coef: float = 0.001
269-
actor_entropy_coef: float = 0.001
269+
actor_entropy_coeff: float = 0.001
270270
actor_grad_clip: float = 1.0
271271
actor_clip_ratio: float = 0.2
272272
# TODO: extract more train-related params from underlying trainer engine

trinity/common/verl_config.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,16 +270,19 @@ class veRLConfig:
270270

271271
def synchronize_config(self, config: Config) -> None:
272272
"""Synchronize config."""
273-
rollout_gpu_num = (
274-
config.explorer.rollout_model.tensor_parallel_size
275-
* config.explorer.rollout_model.engine_num
276-
+ sum(
277-
[
278-
model.tensor_parallel_size * model.engine_num
279-
for model in config.explorer.auxiliary_models
280-
]
273+
if config.mode != "train":
274+
rollout_gpu_num = (
275+
config.explorer.rollout_model.tensor_parallel_size
276+
* config.explorer.rollout_model.engine_num
277+
+ sum(
278+
[
279+
model.tensor_parallel_size * model.engine_num
280+
for model in config.explorer.auxiliary_models
281+
]
282+
)
281283
)
282-
)
284+
else:
285+
rollout_gpu_num = 0
283286
rollout_node_num = rollout_gpu_num // config.cluster.gpu_per_node
284287
self.trainer.nnodes = config.cluster.node_num - rollout_node_num
285288
self.actor_rollout_ref.model.path = config.model.model_path

0 commit comments

Comments
 (0)