Skip to content

Commit 74b1938

Browse files
committed
Specify role_to_logical_axis_rule in ClusterConfig
1 parent 094b41d commit 74b1938

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/MaxText/rl/train_rl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
361361
rl_cluster_lib.Role.REFERENCE: reference_mesh,
362362
rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
363363
},
364+
role_to_logical_axis_rule={
365+
rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules,
366+
rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
367+
},
364368
rollout_engine="vllm",
365369
offload_to_cpu=False,
366370
training_config=rl_cluster_lib.RLTrainingConfig(

0 commit comments

Comments
 (0)