Add support for dynamically setting the number of steps for GRPO.#1257
niting wants to merge 1 commit into google:main
Conversation
tunix/cli/grpo_main.py
Outdated
    dataset=self.config["dataset_name"],
    tfds_download=self.config["tfds_download"],
)
self.compute_params(len(dataset))
I believe not all datasets implement len(), which is why we might have to rely on the config to provide the accurate length if we don't want to go through the dataset once.
That makes sense; if len(...) is not implemented and num_steps is not specified, then we should throw an error, but if dataset.__len__ exists, then we can allow num_steps to be unspecified?
Pretty sure more things will break if len were not implemented. See tunix/cli/utils/data.py:177, which splits the train and test sets. It would be odd for a dataset not to have it implemented, since they are typically just iterator types.
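For context, a split like the one referenced in data.py:177 depends on knowing the dataset length up front. This is a minimal hypothetical sketch (not the actual tunix code) showing why a dataset without __len__ breaks such a split:

```python
def split_dataset(dataset, train_fraction=0.8):
    # len() raises TypeError for plain iterators/generators, which is
    # exactly the breakage being discussed in this thread.
    n_train = int(len(dataset) * train_fraction)
    return dataset[:n_train], dataset[n_train:]


train, test = split_dataset(list(range(10)))  # works: lists implement __len__
```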
Grain supports datasets without len().
Done. I now check whether len is available and enforce that max_steps is required when it isn't. Note that post_init_dataset in tunix/cli/utils/data.py will still break if len is not available; I can fix that separately since it's unrelated to this PR.
Thank you! The fix makes sense to me.
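The agreed-upon behavior could be sketched roughly as follows. This is a hypothetical illustration, not the actual PR code; the names (compute_max_steps, max_steps, batch_size, num_train_epochs, train_fraction) are assumptions:

```python
import math


def compute_max_steps(dataset, batch_size, num_train_epochs,
                      train_fraction=0.8, max_steps=None):
    # Datasets without __len__ (e.g. plain iterators) raise TypeError here.
    try:
        dataset_len = len(dataset)
    except TypeError:
        dataset_len = None

    if dataset_len is None:
        # Without a known length, the user must supply max_steps explicitly.
        if max_steps is None:
            raise ValueError(
                "max_steps must be specified when the dataset does not "
                "implement len()")
        return max_steps

    # A user-specified max_steps takes precedence over the derived value.
    if max_steps is not None:
        return max_steps

    # Otherwise derive the step count from the dataset length.
    num_batches = math.ceil(dataset_len / batch_size)
    return int(num_batches * num_train_epochs * train_fraction)
```

For example, with a 100-example dataset, batch size 8, one epoch, and train_fraction 0.8, this yields int(13 * 1 * 0.8) = 10 steps.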
train_fraction = self.config.get("train_fraction")
if not train_fraction:
  train_fraction = 0.8
if not max_steps:
Shall we check max_steps against int(num_batches * num_train_epochs * train_fraction) when max_steps is provided?
I can, but I was assuming that the user might specify max_steps when they want to try out the behavior with a different number of steps. I could cap max_steps at that value or just leave it as is for now. What do you prefer?
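The capping option being discussed could look like this sketch (hypothetical names, not the actual implementation):

```python
import math


def cap_max_steps(max_steps, dataset_len, batch_size, num_train_epochs,
                  train_fraction):
    # Upper bound: the number of steps implied by the dataset length,
    # batch size, epoch count, and train split fraction.
    num_batches = math.ceil(dataset_len / batch_size)
    upper_bound = int(num_batches * num_train_epochs * train_fraction)
    # A user-supplied max_steps smaller than the bound is kept as is;
    # anything larger is clamped down to the bound.
    return min(max_steps, upper_bound)
```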
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
rl_training_config.actor_optimizer_config.init_value=0.0 \
rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
Maybe we should still set warmup_ratio to 0.1 instead of relying on the default value?
Done. Reverted this change.
examples/rl/grpo/gsm8k/run_qwen3.sh
Outdated
batch_size=${batch_size:-8}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-1.0}
I think we should set train_fraction? The default value is 0.8.
The train_fraction was 1.0. I updated it to 0.8.
Force-pushed from 244a257 to 84827a6.
The existing implementation requires these to be specified by the user. We want users to be able to point to their dataset and have the implementation identify the dataset length. The dataset length is then used to adjust the number of steps required, given the batch size. Updates the Qwen script to use the feature.
Force-pushed from 84827a6 to 9ea04cb.
Reference
Colab Notebook
Checklist
This change has been tested locally by doing a GRPO run and running the Qwen script.