@@ -206,6 +206,36 @@ def setup_configs_and_devices(argv: Sequence[str]):
     return trainer_config, sampler_config, trainer_devices, sampler_devices


+def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):
+    """Get rollout kwargs for vLLM rollout when using data parallelism."""
+    dp = sampler_config.rollout_data_parallelism
+    if dp == -1:
+        return {}
+
+    rollout_kwargs = {}
+    tp = sampler_config.rollout_tensor_parallelism
+
+    if tp == -1:
+        if num_sampler_devices % dp != 0:
+            raise ValueError(
+                f"num_sampler_devices({num_sampler_devices}) must be divisible by "
+                f"rollout_data_parallelism({dp}) "
+                f"when rollout_tensor_parallelism is -1."
+            )
+        tp = num_sampler_devices // dp
+    elif tp * dp != num_sampler_devices:
+        raise ValueError(
+            f"rollout_tensor_parallelism({tp}) * "
+            f"rollout_data_parallelism({dp}) "
+            f"!= len(sampler_devices)({num_sampler_devices})"
+        )
+    rollout_kwargs["tensor_parallel_size"] = tp
+    rollout_kwargs["data_parallel_size"] = dp
+    rollout_kwargs["rollout_vllm_async_scheduling"] = True
+
+    return rollout_kwargs
+
+
 def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     """
     Run RL training with the provided configuration.
@@ -360,6 +390,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
             rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
             rollout_vllm_tpu_backend_type="jax",
             rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+            **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
         ),
     )
     grpo_config = GrpoConfig(
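
For reference, a minimal sketch of what the new helper returns for one concrete configuration (the SimpleNamespace below merely stands in for the real sampler config object; the device count and parallelism values are illustrative, not from the PR):

from types import SimpleNamespace

# Hypothetical sampler config: 8 sampler devices, data parallelism of 4,
# tensor parallelism left at -1 so it is derived from the device count.
sampler_config = SimpleNamespace(
    rollout_data_parallelism=4,
    rollout_tensor_parallelism=-1,
)

kwargs = get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices=8)
# kwargs == {"tensor_parallel_size": 2, "data_parallel_size": 4,
#            "rollout_vllm_async_scheduling": True}

These kwargs are splatted into the rollout config via ** in the second hunk; when rollout_data_parallelism is left at -1, the helper returns an empty dict and the rollout config is unchanged.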