Skip to content

Commit 3d9378d

Browse files
Merge pull request #2802 from AI-Hypercomputer:rl-logprob
PiperOrigin-RevId: 842565881
2 parents 094b41d + 74b1938 commit 3d9378d

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/MaxText/rl/train_rl.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -361,6 +361,10 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
361 361   rl_cluster_lib.Role.REFERENCE: reference_mesh,
362 362   rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
363 363   },
    364 + role_to_logical_axis_rule={
    365 + rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules,
    366 + rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
    367 + },
364 368   rollout_engine="vllm",
365 369   offload_to_cpu=False,
366 370   training_config=rl_cluster_lib.RLTrainingConfig(

0 commit comments

Comments
 (0)