@@ -100,7 +100,7 @@ def get_maxtext_model(config, devices=None):
100100 # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
101101 # load_parameters_path=/path/to/your/output/directory/0/items
102102 """
103- model , mesh = model_creation_utils .create_nnx_model (config , devices )
103+ model , mesh = model_creation_utils .create_nnx_model (config , devices = devices )
104104 with mesh :
105105 tunix_model = TunixMaxTextAdapter (base_model = model )
106106 tunix_model .config = None
@@ -238,7 +238,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
238238 trainer_config .num_batches
239239 * trainer_config .num_iterations
240240 * trainer_config .train_fraction
241- * trainer_config .num_epochs
241+ * trainer_config .num_epoch
242242 )
243243
244244 # ====== Data ======
@@ -260,10 +260,10 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
260260 )[: trainer_config .num_batches ]
261261
262262 if trainer_config .train_fraction == 1.0 :
263- train_dataset = dataset .repeat (trainer_config .num_epochs )
263+ train_dataset = dataset .repeat (trainer_config .num_epoch )
264264 else :
265265 train_dataset = dataset [: int (len (dataset ) * trainer_config .train_fraction )]
266- train_dataset = train_dataset .repeat (trainer_config .num_epochs )
266+ train_dataset = train_dataset .repeat (trainer_config .num_epoch )
267267
268268 test_dataset = get_dataset (model_tokenizer , trainer_config , test_data_dir , trainer_config .eval_split ).batch (
269269 trainer_config .batch_size
@@ -416,7 +416,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
416416 lambda ** kwargs : utils_rl .check_answer (tmvp_config = trainer_config , ** kwargs ),
417417 lambda ** kwargs : utils_rl .check_numbers (tmvp_config = trainer_config , ** kwargs ),
418418 ],
419- grpo_config = grpo_config ,
419+ algo_config = grpo_config ,
420420 )
421421
422422 # Before we train the model, let's evaluate the model on the test set so we can
0 commit comments