
Commit 094b41d

Merge pull request #2713 from AI-Hypercomputer:xfgu-dp-rl
PiperOrigin-RevId: 842472785
2 parents: 6f35ece + 53e704c · commit 094b41d

File tree

3 files changed: +43 -0 lines changed


src/MaxText/configs/rl.yml

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@ sampler_devices_fraction: 0.5
 chips_per_vm: 4 # depends on hardware, for v5p this is 4
 num_trainer_slices: -1
 num_samplers_slices: -1
+# Only specify rollout_data_parallelism when you would like to use more than one model
+# replicas in rollout. If not specified, rollout_tensor_parallelism will be auto-determined.
+rollout_data_parallelism: -1
+rollout_tensor_parallelism: -1
 
 # ====== Reproducibility ======
 data_shuffle_seed: 42
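
For a concrete sense of how the two knobs interact, here is a small worked sketch; the 8-chip sampler count and the dp=2 setting are made-up illustration values, not something set by this commit. When only rollout_data_parallelism is specified, the per-replica tensor parallelism is derived from the sampler device count.

# Hypothetical numbers: 8 sampler devices, rollout_data_parallelism=2, rollout_tensor_parallelism left at -1.
num_sampler_devices = 8
rollout_data_parallelism = 2
# Auto-derived per-replica tensor parallelism: 8 // 2 = 4.
rollout_tensor_parallelism = num_sampler_devices // rollout_data_parallelism
print(rollout_tensor_parallelism)  # 4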

src/MaxText/configs/types.py

Lines changed: 8 additions & 0 deletions
@@ -1333,6 +1333,14 @@ class RLHardware(BaseModel):
   use_pathways: bool = Field(True, description="Whether to use Pathways for multihost orchestration.")
   num_trainer_slices: int = Field(-1, description="Number of slices for the trainer.")
   num_samplers_slices: int = Field(-1, description="Number of slices for the samplers.")
+  rollout_data_parallelism: int = Field(
+      -1,
+      description="Total model replicas for rollout. It should only be specified when you would like to use more "
+      "than one model replica in rollout.",
+  )
+  rollout_tensor_parallelism: int = Field(
+      -1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
+  )
 
 
 class VLLM(BaseModel):
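
As a quick illustration of the new fields' defaults, here is a minimal pydantic sketch; RolloutParallelism is a hypothetical stand-in, not the real RLHardware class, which carries many more fields.

from pydantic import BaseModel, Field

class RolloutParallelism(BaseModel):  # hypothetical stand-in for the RLHardware fields above
  rollout_data_parallelism: int = Field(-1, description="Total model replicas for rollout.")
  rollout_tensor_parallelism: int = Field(-1, description="Tensor parallelism per replica for rollout.")

print(RolloutParallelism().rollout_data_parallelism)   # -1: data-parallel rollout disabled by default
print(RolloutParallelism(rollout_data_parallelism=2))  # dp=2, tp stays -1 and is derived at runtime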

src/MaxText/rl/train_rl.py

Lines changed: 31 additions & 0 deletions
@@ -206,6 +206,36 @@ def setup_configs_and_devices(argv: Sequence[str]):
   return trainer_config, sampler_config, trainer_devices, sampler_devices
 
 
+def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):
+  """Get rollout kwargs for vLLM rollout when using data parallelism."""
+  dp = sampler_config.rollout_data_parallelism
+  if dp == -1:
+    return {}
+
+  rollout_kwargs = {}
+  tp = sampler_config.rollout_tensor_parallelism
+
+  if tp == -1:
+    if num_sampler_devices % dp != 0:
+      raise ValueError(
+          f"num_sampler_devices({num_sampler_devices}) must be divisible by "
+          f"rollout_data_parallelism({dp}) "
+          f"when rollout_tensor_parallelism is -1."
+      )
+    tp = num_sampler_devices // dp
+  elif tp * dp != num_sampler_devices:
+    raise ValueError(
+        f"rollout_tensor_parallelism({tp}) * "
+        f"rollout_data_parallelism({dp}) "
+        f"!= len(sampler_devices)({num_sampler_devices})"
+    )
+  rollout_kwargs["tensor_parallel_size"] = tp
+  rollout_kwargs["data_parallel_size"] = dp
+  rollout_kwargs["rollout_vllm_async_scheduling"] = True
+
+  return rollout_kwargs
+
+
 def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   """
   Run RL training with the provided configuration.

@@ -360,6 +390,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
          rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
          rollout_vllm_tpu_backend_type="jax",
          rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+         **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
      ),
  )
  grpo_config = GrpoConfig(
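
To make the branching in get_rollout_kwargs_for_data_parallelism concrete, here is a small usage sketch; the SimpleNamespace objects stand in for the real sampler config, and the device counts are made up for illustration.

from types import SimpleNamespace

# dp left at -1: no extra kwargs, rollout parallelism keeps the existing auto-determined behavior.
cfg = SimpleNamespace(rollout_data_parallelism=-1, rollout_tensor_parallelism=-1)
assert get_rollout_kwargs_for_data_parallelism(cfg, num_sampler_devices=8) == {}

# dp=2 with tp unspecified: per-replica tensor parallelism is derived as 8 // 2 = 4.
cfg = SimpleNamespace(rollout_data_parallelism=2, rollout_tensor_parallelism=-1)
print(get_rollout_kwargs_for_data_parallelism(cfg, 8))
# {'tensor_parallel_size': 4, 'data_parallel_size': 2, 'rollout_vllm_async_scheduling': True}

# dp * tp must equal the sampler device count; 2 * 3 != 8 raises a ValueError.
cfg = SimpleNamespace(rollout_data_parallelism=2, rollout_tensor_parallelism=3)
try:
  get_rollout_kwargs_for_data_parallelism(cfg, 8)
except ValueError as err:
  print(err)

These kwargs are then splatted into the RolloutConfig call shown in the second hunk above, so data-parallel rollout only changes the vLLM rollout setup when rollout_data_parallelism is explicitly set.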
