# Add support for dynamically setting the number of steps for GRPO. #1257
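From the diff: when `rl_training_config.max_steps` is not set, the trainer now derives it from the dataset as `(len(dataset) // batch_size) * num_train_epochs * train_fraction`, and threads the result into the actor optimizer's `decay_steps` and `warmup_steps`. To make the dataset length observable early enough, tokenizer creation moves out of `create_rl_cluster`, which now receives the tokenizer as an argument.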
Base: `main`
```diff
@@ -151,7 +151,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
     ).export_metrics
     return perf_config

-  def create_rl_cluster(self):
+  def create_rl_cluster(self, tokenizer):
     # Should not use LoRA for reference model.
     if self.config["reference_model_config"].get("lora_config"):
       logging.warning(
```
```diff
@@ -177,10 +177,6 @@ def create_rl_cluster(self):
         jax.tree.map(jnp.copy, params),
     )

-    tokenizer = model_lib.create_tokenizer(
-        self.config["tokenizer_config"], tokenizer_path
-    )
-
     cluster_config = self.create_cluster_config()
     perf_config = self.create_perf_config(cluster_config)
     return rl_cluster_lib.RLCluster(
```
```diff
@@ -191,14 +187,67 @@ def create_rl_cluster(self):
         perf_config=perf_config,
     )

+  def compute_params(self, dataset):
+    rl_training_config = self.config.get("rl_training_config", {})
+
+    # Return early if max_steps is already specified.
+    max_steps = None
+    if rl_training_config.get("max_steps"):
+      max_steps = rl_training_config.get("max_steps")
+    elif not hasattr(dataset, "__len__"):
+      raise ValueError(
+          "max_steps must be specified since the dataset length cannot be"
+          " determined."
+      )
+
+    dataset_length = len(dataset)
+
+    batch_size = self.config.get("batch_size", 1)
+    num_batches = self.config.get("num_batches")
+    if not num_batches:
+      num_batches = dataset_length // batch_size
+      logging.info(
+          "Dynamically computed num_batches=%d with batch_size=%d",
+          num_batches,
+          batch_size,
+      )
+    num_train_epochs = self.config.get("num_train_epochs")
+    if not num_train_epochs:
+      num_train_epochs = 1
+
+    train_fraction = self.config.get("train_fraction")
+    if not train_fraction:
+      train_fraction = 0.8
+    elif train_fraction <= 0.0 or train_fraction > 1.0:
+      logging.warning(
+          f"train_fraction {train_fraction:.2f} out of expected range. Setting"
+          " to 0.8"
+      )
+      train_fraction = 0.8
+
+    if not max_steps:
```
**Collaborator:** Shall we check the `max_steps` against `int(num_batches * num_train_epochs * train_fraction)` if `max_steps` is available?

**Contributor (author):** I can, but I was assuming that the user might specify `max_steps` when they want to really try out the behavior with different steps. I could potentially cap `max_steps` to that value, or just leave it as is for now. What do you prefer?
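If the suggested cap were adopted, it might look like the following (hypothetical; not part of the diff):

```python
# Never schedule more steps than the dataset can actually supply.
derived_steps = int(num_batches * num_train_epochs * train_fraction)
max_steps = min(max_steps, derived_steps) if max_steps else derived_steps
```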
```diff
+      max_steps = int(num_batches * num_train_epochs * train_fraction)
+
+    rl_training_config["max_steps"] = max_steps
+    actor_opt = rl_training_config.get("actor_optimizer_config", {})
+    if actor_opt and not actor_opt.get("decay_steps"):
+      actor_opt["decay_steps"] = max_steps
+    if actor_opt and not actor_opt.get("warmup_steps"):
+      warmup_ratio = self.config.get("warmup_ratio", 0.1)
+      warmup_steps = self.config.get("warmup_steps", warmup_ratio * max_steps)
+      actor_opt["warmup_steps"] = warmup_steps
+    logging.info(
+        "Dynamically computed max_steps=%d based on dataset length %d",
+        max_steps,
+        dataset_length,
+    )
+
   def run_grpo_trainer(self):
-    grpo_trainer = grpo_learner.GrpoLearner(
-        rl_cluster=self.create_rl_cluster(),
-        reward_fns=self.obtain_reward_fn(),
-        algo_config=GrpoConfig(**self.config["grpo_config"]),
-    )
-
-    tokenizer = grpo_trainer.rl_cluster.tokenizer
+    tokenizer = model_lib.create_tokenizer(
+        self.config["tokenizer_config"],
+        self.config["tokenizer_config"]["tokenizer_path"],
+    )
     if self.config.get("data_module", None):
       dataset = data_lib.get_dataset_from_module(
           self.config["data_module"],
```
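The `compute_params` arithmetic, pulled out of the class for illustration. This helper and its config dict are a sketch, not the repository's API; only the formula and the defaults (`batch_size=1`, one epoch, `train_fraction=0.8`) mirror the diff:

```python
import logging


def derive_max_steps(dataset_length: int, config: dict) -> int:
  """Sketch of the diff's step arithmetic: batches * epochs * train fraction."""
  batch_size = config.get("batch_size", 1)
  num_batches = config.get("num_batches") or dataset_length // batch_size
  num_train_epochs = config.get("num_train_epochs") or 1
  train_fraction = config.get("train_fraction") or 0.8
  if train_fraction <= 0.0 or train_fraction > 1.0:
    logging.warning("train_fraction %.2f out of range; using 0.8", train_fraction)
    train_fraction = 0.8
  return int(num_batches * num_train_epochs * train_fraction)


# Example: 10,000 prompts, batch size 8, 2 epochs, default 0.8 train fraction:
# (10000 // 8) * 2 * 0.8 = 2000 steps.
assert derive_max_steps(10_000, {"batch_size": 8, "num_train_epochs": 2}) == 2000
```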
```diff
@@ -216,6 +265,7 @@ def run_grpo_trainer(self):
           dataset=self.config["dataset_name"],
           tfds_download=self.config["tfds_download"],
       )
+    self.compute_params(dataset)
     dataset, _ = data_lib.post_init_dataset(
         dataset,
         tokenizer,
```
```diff
@@ -225,6 +275,12 @@ def run_grpo_trainer(self):
             "max_prompt_length", None
         ),
     )
+    rl_cluster = self.create_rl_cluster(tokenizer)
+    grpo_trainer = grpo_learner.GrpoLearner(
+        rl_cluster=rl_cluster,
+        reward_fns=self.obtain_reward_fn(),
+        algo_config=GrpoConfig(**self.config["grpo_config"]),
+    )
     grpo_trainer.train(dataset)
```
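Linearized, the new `run_grpo_trainer` control flow reads as below (paraphrased from the hunks above; elided arguments are marked with `...`):

```python
# 1. Tokenizer is created first, independent of the RL cluster.
tokenizer = model_lib.create_tokenizer(
    self.config["tokenizer_config"],
    self.config["tokenizer_config"]["tokenizer_path"],
)
# 2. Raw dataset is loaded (from a data module or TFDS, per config).
dataset = ...
# 3. max_steps / decay_steps / warmup_steps are derived while the
#    dataset's length is still observable.
self.compute_params(dataset)
# 4. Dataset is tokenized and post-processed.
dataset, _ = data_lib.post_init_dataset(dataset, tokenizer, ...)
# 5. Cluster and learner are built only after the schedule is fixed.
rl_cluster = self.create_rl_cluster(tokenizer)
grpo_trainer = grpo_learner.GrpoLearner(
    rl_cluster=rl_cluster,
    reward_fns=self.obtain_reward_fn(),
    algo_config=GrpoConfig(**self.config["grpo_config"]),
)
grpo_trainer.train(dataset)
```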
**Collaborator:** Maybe we should still set `warmup_ratio` to 0.1 instead of relying on the default value?

**Contributor (author):** Done. Reverted this change.
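For context, the fallback under discussion: when `warmup_steps` is absent, the diff derives it from `warmup_ratio` (default 0.1) times the derived `max_steps`. A worked example with hypothetical numbers:

```python
max_steps = 2000                         # derived from the dataset as above
warmup_ratio = 0.1                       # the diff's fallback default
warmup_steps = warmup_ratio * max_steps  # 200.0 warmup steps
decay_steps = max_steps                  # decay spans the whole run
```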