@@ -151,7 +151,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
151151 ).export_metrics
152152 return perf_config
153153
154- def create_rl_cluster (self ):
154+ def create_rl_cluster (self , tokenizer ):
155155 # Should not use LoRA for reference model.
156156 if self .config ["reference_model_config" ].get ("lora_config" ):
157157 logging .warning (
@@ -177,10 +177,6 @@ def create_rl_cluster(self):
177177 jax .tree .map (jnp .copy , params ),
178178 )
179179
180- tokenizer = model_lib .create_tokenizer (
181- self .config ["tokenizer_config" ], tokenizer_path
182- )
183-
184180 cluster_config = self .create_cluster_config ()
185181 perf_config = self .create_perf_config (cluster_config )
186182 return rl_cluster_lib .RLCluster (
@@ -191,14 +187,67 @@ def create_rl_cluster(self):
191187 perf_config = perf_config ,
192188 )
193189
def compute_params(self, dataset):
  """Fills in schedule-related training parameters derived from the dataset.

  Mutates ``self.config["rl_training_config"]`` in place: ensures
  ``max_steps`` is set (deriving it from the dataset length when absent),
  and, when an ``actor_optimizer_config`` is present, defaults its
  ``decay_steps`` and ``warmup_steps`` from ``max_steps``.

  Args:
    dataset: The training dataset. Must support ``len()`` unless
      ``max_steps`` is already specified in ``rl_training_config``.

  Raises:
    ValueError: If ``max_steps`` is unspecified and the dataset length
      cannot be determined.
  """
  # setdefault (not get) so mutations below survive when the key is absent.
  rl_training_config = self.config.setdefault("rl_training_config", {})

  max_steps = rl_training_config.get("max_steps")
  if not max_steps:
    # max_steps has to be derived from the dataset size, so the dataset
    # must be sized. When max_steps is given we never call len(dataset),
    # which keeps plain iterables usable.
    if not hasattr(dataset, "__len__"):
      raise ValueError(
          "max_steps must be specified since the dataset length cannot be"
          " determined."
      )
    dataset_length = len(dataset)

    batch_size = self.config.get("batch_size", 1)
    num_batches = self.config.get("num_batches")
    if not num_batches:
      num_batches = dataset_length // batch_size
      logging.info(
          "Dynamically computed num_batches=%d with batch_size=%d",
          num_batches,
          batch_size,
      )

    num_train_epochs = self.config.get("num_train_epochs") or 1

    train_fraction = self.config.get("train_fraction")
    if not train_fraction:
      train_fraction = 0.8
    elif train_fraction <= 0.0 or train_fraction > 1.0:
      # Fixed: the original used `and`, which can never be true, so
      # out-of-range fractions were silently accepted.
      logging.warning(
          "train_fraction %.2f out of expected range. Setting to 0.8",
          train_fraction,
      )
      train_fraction = 0.8

    max_steps = int(num_batches * num_train_epochs * train_fraction)
    # Only log when the value was actually derived; logging it for a
    # user-specified max_steps was misleading.
    logging.info(
        "Dynamically computed max_steps=%d based on dataset length %d",
        max_steps,
        dataset_length,
    )

  rl_training_config["max_steps"] = max_steps

  actor_opt = rl_training_config.get("actor_optimizer_config", {})
  if actor_opt and not actor_opt.get("decay_steps"):
    actor_opt["decay_steps"] = max_steps
  if actor_opt and not actor_opt.get("warmup_steps"):
    warmup_ratio = self.config.get("warmup_ratio", 0.1)
    actor_opt["warmup_steps"] = self.config.get(
        "warmup_steps", warmup_ratio * max_steps
    )
194245 def run_grpo_trainer (self ):
195- grpo_trainer = grpo_learner .GrpoLearner (
196- rl_cluster = self .create_rl_cluster (),
197- reward_fns = self .obtain_reward_fn (),
198- algo_config = GrpoConfig (** self .config ["grpo_config" ]),
246+ tokenizer = model_lib .create_tokenizer (
247+ self .config ["tokenizer_config" ],
248+ self .config ["tokenizer_config" ]["tokenizer_path" ],
199249 )
200250
201- tokenizer = grpo_trainer .rl_cluster .tokenizer
202251 if self .config .get ("data_module" , None ):
203252 dataset = data_lib .get_dataset_from_module (
204253 self .config ["data_module" ],
@@ -216,6 +265,7 @@ def run_grpo_trainer(self):
216265 dataset = self .config ["dataset_name" ],
217266 tfds_download = self .config ["tfds_download" ],
218267 )
268+ self .compute_params (dataset )
219269 dataset , _ = data_lib .post_init_dataset (
220270 dataset ,
221271 tokenizer ,
@@ -225,6 +275,12 @@ def run_grpo_trainer(self):
225275 "max_prompt_length" , None
226276 ),
227277 )
278+ rl_cluster = self .create_rl_cluster (tokenizer )
279+ grpo_trainer = grpo_learner .GrpoLearner (
280+ rl_cluster = rl_cluster ,
281+ reward_fns = self .obtain_reward_fn (),
282+ algo_config = GrpoConfig (** self .config ["grpo_config" ]),
283+ )
228284 grpo_trainer .train (dataset )
229285
230286
0 commit comments