
Add support for dynamically setting the number of steps for GRPO.#1257

Open
niting wants to merge 1 commit intogoogle:mainfrom
niting:niting/compute_steps

Conversation


@niting niting commented Mar 17, 2026

The existing implementation requires the number of steps to be specified by the user. We want users to be able to point to their dataset and have the implementation identify the dataset's length. The dataset length is then used to compute the number of steps required, given the batch size.

Updates the Qwen script to use the feature.
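For intuition, the derivation described above can be sketched as a small helper. This is an illustrative sketch only, not the PR's actual code; the function name and defaults (`train_fraction=0.8`, one epoch) are assumptions based on the discussion below.

```python
import math

def compute_num_steps(dataset_len, batch_size, num_train_epochs=1, train_fraction=0.8):
    """Derive a training step count from the dataset size (illustrative sketch).

    All names here are hypothetical; the PR's real logic lives in compute_params.
    """
    # Only the training split of the dataset contributes batches.
    train_examples = int(dataset_len * train_fraction)
    # A partial final batch still costs one step, hence the ceiling.
    num_batches = math.ceil(train_examples / batch_size)
    return num_batches * num_train_epochs

# 1000 examples, batch size 8, one epoch, 80% train split -> 100 steps
print(compute_num_steps(1000, batch_size=8))  # prints 100
```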


Reference

Colab Notebook

Checklist

This change has been tested locally by doing a GRPO run and running the Qwen script.

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.


    dataset=self.config["dataset_name"],
    tfds_download=self.config["tfds_download"],
)
self.compute_params(len(dataset))
Collaborator

I believe not all datasets implement `len`, which is why we might have to rely on the config to provide an accurate length if we don't want to iterate through the dataset once.

Collaborator

That makes sense; if `len(...)` is not implemented and `num_steps` is not specified, we should throw an error, but if the dataset does implement `len`, we should allow `num_steps` to be omitted?

Contributor Author

Pretty sure more things would break if `len` were not implemented. See tunix/cli/utils/data.py:177, which splits the train and test sets. It would be odd for a dataset not to implement it, since they are typically just iterator types.

Collaborator

Grain supports datasets without `len()`.
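For context on why this matters: an iterable-style dataset need not implement `__len__` at all, and calling `len()` on one raises `TypeError`. A minimal illustration in plain Python (not Grain-specific; the class is hypothetical):

```python
class StreamingDataset:
    """An iterable-style dataset whose length is unknowable (e.g. a stream)."""

    def __iter__(self):
        # In practice this might be unbounded; three items for illustration.
        yield from ({"text": f"example {i}"} for i in range(3))

    # Deliberately no __len__ defined.

try:
    len(StreamingDataset())
except TypeError:
    print("len() is not available for this dataset")
```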

Contributor Author

Done. I now check whether `len` is available and require `max_steps` when it is not. Note that `post_init_dataset` in tunix/cli/utils/data.py will still break if `len` is unavailable; I can fix that separately since it is unrelated to this PR.
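The check described here could look roughly like the following. This is a hypothetical helper mirroring the described behavior, not the PR's actual code; the name `resolve_max_steps` and the ceiling-division fallback are assumptions.

```python
def resolve_max_steps(dataset, batch_size, max_steps=None):
    """Return max_steps, deriving it from len(dataset) when possible (sketch)."""
    try:
        dataset_len = len(dataset)
    except TypeError:
        # Dataset has no len(); the caller must supply max_steps explicitly.
        if max_steps is None:
            raise ValueError("max_steps is required when the dataset has no len()")
        return max_steps
    # len() is available: honor an explicit max_steps, else compute one.
    computed = -(-dataset_len // batch_size)  # ceiling division
    return max_steps if max_steps is not None else computed
```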

Collaborator

Thank you! The fix makes sense to me.

train_fraction = self.config.get("train_fraction")
if not train_fraction:
    train_fraction = 0.8
if not max_steps:
Collaborator

Shall we check `max_steps` against `int(num_batches * num_train_epochs * train_fraction)` when `max_steps` is provided?

Contributor Author

I can, but I was assuming the user might specify `max_steps` precisely when they want to try out the behavior with a different number of steps. I could cap `max_steps` at that value, or leave it as is for now. Which do you prefer?
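The capping option under discussion could be a one-liner. A minimal sketch, assuming the cap is taken (the function name and the choice to cap rather than raise are illustrative, not what the PR settled on):

```python
def cap_max_steps(max_steps, num_batches, num_train_epochs, train_fraction):
    """Cap a user-supplied max_steps at the data-derived bound (sketch only)."""
    computed = int(num_batches * num_train_epochs * train_fraction)
    if max_steps is None:
        return computed
    # Never run longer than the data actually supports.
    return min(max_steps, computed)

print(cap_max_steps(500, num_batches=125, num_train_epochs=1, train_fraction=0.8))  # prints 100
```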

rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
rl_training_config.actor_optimizer_config.init_value=0.0 \
rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
Collaborator

Maybe we should still set `warmup_ratio` to 0.1 explicitly instead of relying on the default value?

Contributor Author

Done. Reverted this change.

batch_size=${batch_size:-8}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-1.0}
Collaborator

I think we should set `train_fraction` explicitly? The default value is 0.8.

Contributor Author

The `train_fraction` was 1.0. I updated it to 0.8.

@niting niting force-pushed the niting/compute_steps branch 2 times, most recently from 244a257 to 84827a6 Compare March 21, 2026 22:55
@niting niting force-pushed the niting/compute_steps branch from 84827a6 to 9ea04cb Compare March 23, 2026 01:24