3131from tunix .rl import rl_cluster as rl_cluster_lib
3232from tunix .rl .grpo import grpo_learner
3333from tunix .rl .rollout import base_rollout
34+ from typing import Any
3435
3536GrpoConfig = grpo_learner .GrpoConfig
3637
@@ -154,7 +155,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
154155 )
155156 return perf_config
156157
157- def create_rl_cluster (self ):
158+ def create_rl_cluster (self , tokenizer ):
158159 # Should not use LoRA for reference model.
159160 if self .config ["reference_model_config" ].get ("lora_config" ):
160161 logging .warning (
@@ -180,10 +181,6 @@ def create_rl_cluster(self):
180181 jax .tree .map (jnp .copy , params ),
181182 )
182183
183- tokenizer = model_lib .create_tokenizer (
184- self .config ["tokenizer_config" ], tokenizer_path
185- )
186-
187184 cluster_config = self .create_cluster_config ()
188185 perf_config = self .create_perf_config (cluster_config )
189186 return rl_cluster_lib .RLCluster (
@@ -194,14 +191,76 @@ def create_rl_cluster(self):
194191 perf_config = perf_config ,
195192 )
196193
def compute_params(self, dataset):
  """Derives and validates step counts for RL training from the dataset.

  Populates ``rl_training_config["max_steps"]`` (computing it from the
  dataset length, batch size, epoch count and train fraction when the user
  did not specify it), and fills in ``decay_steps`` / ``warmup_steps`` on
  the actor optimizer config when they are missing.

  Args:
    dataset: The training dataset. Must support ``len()`` unless
      ``rl_training_config["max_steps"]`` is already specified.

  Raises:
    ValueError: If neither ``max_steps`` nor the dataset length is
      available, or if the user-supplied ``max_steps`` exceeds the number
      of steps the dataset can actually provide.
  """
  # setdefault (not get) so mutations below are visible in self.config even
  # when the key was absent; with a plain get the updates would be lost.
  rl_training_config: dict[str, Any] = self.config.setdefault(
      "rl_training_config", {}
  )

  max_steps = rl_training_config.get("max_steps")
  has_length = hasattr(dataset, "__len__")
  if not max_steps and not has_length:
    raise ValueError(
        "max_steps must be specified since the dataset length cannot be"
        " determined."
    )

  # Only derive/validate step budgets when the dataset length is known; a
  # user-supplied max_steps with a length-less dataset is accepted as-is
  # (previously this crashed on len(dataset)).
  if has_length:
    dataset_length = len(dataset)

    batch_size = self.config.get("batch_size", 1)
    num_batches = self.config.get("num_batches")
    if not num_batches:
      num_batches = dataset_length // batch_size
      logging.info(
          "Dynamically computed num_batches=%d with batch_size=%d",
          num_batches,
          batch_size,
      )

    num_train_epochs = self.config.get("num_train_epochs") or 1

    train_fraction = self.config.get("train_fraction")
    if not train_fraction:
      train_fraction = 0.8
    elif not 0.0 < train_fraction <= 1.0:
      # Fixed: the original condition used `and`, which could never be true,
      # so out-of-range fractions were silently accepted.
      logging.warning(
          f"train_fraction {train_fraction:.2f} out of expected range. Setting"
          " to 0.8"
      )
      train_fraction = 0.8

    allowed_max_steps = int(num_batches * num_train_epochs * train_fraction)
    if not max_steps:
      max_steps = allowed_max_steps
    elif max_steps > allowed_max_steps:
      # Fixed: exceptions do not support logging-style lazy %-args; the
      # message must be formatted eagerly.
      raise ValueError(
          f"Maximum allowed value for max_steps is {allowed_max_steps}"
      )
    logging.info(
        "Dynamically computed max_steps=%d based on dataset length %d",
        max_steps,
        dataset_length,
    )

  rl_training_config["max_steps"] = max_steps
  actor_opt: dict[str, Any] = rl_training_config.get(
      "actor_optimizer_config", {}
  )
  if actor_opt and not actor_opt.get("decay_steps"):
    actor_opt["decay_steps"] = max_steps
  if actor_opt and not actor_opt.get("warmup_steps"):
    # Explicit warmup_steps wins; otherwise derive it from warmup_ratio.
    warmup_ratio = self.config.get("warmup_ratio", 0.1)
    actor_opt["warmup_steps"] = self.config.get(
        "warmup_steps", warmup_ratio * max_steps
    )
197258 def run_grpo_trainer (self ):
198- grpo_trainer = grpo_learner .GrpoLearner (
199- rl_cluster = self .create_rl_cluster (),
200- reward_fns = self .obtain_reward_fn (),
201- algo_config = GrpoConfig (** self .config ["grpo_config" ]),
259+ tokenizer = model_lib .create_tokenizer (
260+ self .config ["tokenizer_config" ],
261+ self .config ["tokenizer_config" ]["tokenizer_path" ],
202262 )
203263
204- tokenizer = grpo_trainer .rl_cluster .tokenizer
205264 if self .config .get ("data_module" , None ):
206265 dataset = data_lib .get_dataset_from_module (
207266 self .config ["data_module" ],
@@ -219,6 +278,7 @@ def run_grpo_trainer(self):
219278 dataset = self .config ["dataset_name" ],
220279 tfds_download = self .config ["tfds_download" ],
221280 )
281+ self .compute_params (dataset )
222282 dataset , _ = data_lib .post_init_dataset (
223283 dataset ,
224284 tokenizer ,
@@ -228,6 +288,12 @@ def run_grpo_trainer(self):
228288 "max_prompt_length" , None
229289 ),
230290 )
291+ rl_cluster = self .create_rl_cluster (tokenizer )
292+ grpo_trainer = grpo_learner .GrpoLearner (
293+ rl_cluster = rl_cluster ,
294+ reward_fns = self .obtain_reward_fn (),
295+ algo_config = GrpoConfig (** self .config ["grpo_config" ]),
296+ )
231297 grpo_trainer .train (dataset )
232298
233299
0 commit comments