Skip to content

Commit 84827a6

Browse files
committed
Add support for dynamically setting the number of steps for GRPO.
The existing implementation requires these to be specified by the user. We want users to be able to point to their dataset, and the implementation should identify the length of the dataset. The dataset length is then used to compute the number of steps required, given the batch size. Updates the Qwen script to use the feature.
1 parent 8e586d6 commit 84827a6

File tree

3 files changed

+69
-27
lines changed

3 files changed

+69
-27
lines changed

examples/rl/grpo/gsm8k/run_qwen3.sh

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,17 @@ set -x # Enable xtrace
1717

1818
# specify at cmd line to override defaults, e.g.
1919
model_name=${model_name:-"Qwen3-1.7B-base"}
20-
batch_size=${batch_size:-1}
21-
num_batches=${num_batches:-3738}
20+
batch_size=${batch_size:-8}
2221
num_train_epochs=${num_train_epochs:-1}
2322
warmup_ratio=${warmup_ratio:-0.1}
24-
train_fraction=${train_fraction:-1.0}
23+
train_fraction=${train_fraction:-0.8}
2524

2625
echo "Using parameters:"
2726
echo " Batch Size: $batch_size"
28-
echo " Num Batches: $num_batches"
2927
echo " Num Epochs: $num_train_epochs"
3028
echo " Warmup Ratio: $warmup_ratio"
3129
echo " Train Fraction: $train_fraction"
3230

33-
max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
34-
max_steps=$(printf "%.0f" "$max_steps_float")
35-
warmup_steps=$(awk "BEGIN {printf \"%.0f\", $warmup_ratio * $max_steps}")
36-
37-
echo "Max steps: $max_steps"
38-
echo "Rounded warmup steps: $warmup_steps"
39-
4031
python3 -m tunix.cli.grpo_main \
4132
base_config.yaml \
4233
model_config.model_name=${model_name} \
@@ -58,7 +49,6 @@ python3 -m tunix.cli.grpo_main \
5849
tokenizer_config.add_bos=false \
5950
dataset_name="gsm8k" \
6051
batch_size=$batch_size \
61-
num_batches=$num_batches \
6252
num_test_batches=100 \
6353
num_train_epochs=$num_train_epochs \
6454
rl_training_config.actor_optimizer_config.opt_type="adamw" \
@@ -67,14 +57,11 @@ python3 -m tunix.cli.grpo_main \
6757
rl_training_config.actor_optimizer_config.init_value=0.0 \
6858
rl_training_config.actor_optimizer_config.end_value=0.0 \
6959
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
70-
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
71-
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
7260
rl_training_config.actor_optimizer_config.b1=0.9 \
7361
rl_training_config.actor_optimizer_config.b2=0.99 \
7462
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
7563
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
7664
rl_training_config.eval_every_n_steps=10 \
77-
rl_training_config.max_steps=$max_steps \
7865
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
7966
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
8067
rl_training_config.checkpointing_options.save_interval_steps=500 \

tunix/cli/base_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ data_directory: ""
9494
data_module: ""
9595
dataset_name: "Helsinki-NLP/opus-100"
9696
batch_size: 16
97-
num_batches: 3738
9897
max_target_length: 256
9998
num_train_epochs: 1
10099
train_fraction: 1.0
@@ -210,4 +209,4 @@ offload_to_cpu: false
210209

211210
verl_compatible: false
212211
reward_functions:
213-
- tunix/cli/reward_fn/gsm8k.py
212+
- tunix/cli/reward_fn/gsm8k.py

tunix/cli/grpo_main.py

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
151151
).export_metrics
152152
return perf_config
153153

154-
def create_rl_cluster(self):
154+
def create_rl_cluster(self, tokenizer):
155155
# Should not use LoRA for reference model.
156156
if self.config["reference_model_config"].get("lora_config"):
157157
logging.warning(
@@ -177,10 +177,6 @@ def create_rl_cluster(self):
177177
jax.tree.map(jnp.copy, params),
178178
)
179179

180-
tokenizer = model_lib.create_tokenizer(
181-
self.config["tokenizer_config"], tokenizer_path
182-
)
183-
184180
cluster_config = self.create_cluster_config()
185181
perf_config = self.create_perf_config(cluster_config)
186182
return rl_cluster_lib.RLCluster(
@@ -191,14 +187,67 @@ def create_rl_cluster(self):
191187
perf_config=perf_config,
192188
)
193189

190+
def compute_params(self, dataset):
  """Derive step-dependent training params from the dataset size.

  Computes ``rl_training_config.max_steps`` from the dataset length,
  batch size, epoch count, and train fraction when the user did not
  specify it, and fills in unset ``decay_steps`` / ``warmup_steps`` on
  the actor optimizer config from the resulting ``max_steps``.

  Args:
    dataset: The training dataset. Must support ``len()`` unless
      ``rl_training_config.max_steps`` is already set in the config.

  Raises:
    ValueError: If ``max_steps`` is unset and the dataset length cannot
      be determined (dataset has no ``__len__``).
  """
  # setdefault (not get) so computed values are stored back into
  # self.config even when the key was absent; a plain get() would write
  # into a throwaway dict and silently drop every computed param.
  rl_training_config = self.config.setdefault("rl_training_config", {})

  max_steps = rl_training_config.get("max_steps")
  if not max_steps:
    # max_steps must be derived from the dataset, so we need its length.
    if not hasattr(dataset, "__len__"):
      raise ValueError(
          "max_steps must be specified since the dataset length cannot be"
          " determined."
      )
    dataset_length = len(dataset)

    batch_size = self.config.get("batch_size", 1)
    num_batches = self.config.get("num_batches")
    if not num_batches:
      num_batches = dataset_length // batch_size
      logging.info(
          "Dynamically computed num_batches=%d with batch_size=%d",
          num_batches,
          batch_size,
      )

    num_train_epochs = self.config.get("num_train_epochs") or 1

    train_fraction = self.config.get("train_fraction")
    if not train_fraction:
      train_fraction = 0.8
    elif train_fraction <= 0.0 or train_fraction > 1.0:
      # NOTE: was `and`, which is never true — out-of-range fractions
      # slipped through unclamped. `or` is the intended range check.
      logging.warning(
          "train_fraction %.2f out of expected range. Setting to 0.8",
          train_fraction,
      )
      train_fraction = 0.8

    max_steps = int(num_batches * num_train_epochs * train_fraction)
    logging.info(
        "Dynamically computed max_steps=%d based on dataset length %d",
        max_steps,
        dataset_length,
    )

  rl_training_config["max_steps"] = max_steps

  # Back-fill optimizer schedule lengths that depend on max_steps.
  actor_opt = rl_training_config.get("actor_optimizer_config", {})
  if actor_opt:
    if not actor_opt.get("decay_steps"):
      actor_opt["decay_steps"] = max_steps
    if not actor_opt.get("warmup_steps"):
      warmup_ratio = self.config.get("warmup_ratio", 0.1)
      # Round to an integer step count; warmup_ratio * max_steps is a
      # float and optimizer schedules expect whole steps.
      actor_opt["warmup_steps"] = int(
          round(self.config.get("warmup_steps", warmup_ratio * max_steps))
      )
244+
194245
def run_grpo_trainer(self):
195-
grpo_trainer = grpo_learner.GrpoLearner(
196-
rl_cluster=self.create_rl_cluster(),
197-
reward_fns=self.obtain_reward_fn(),
198-
algo_config=GrpoConfig(**self.config["grpo_config"]),
246+
tokenizer = model_lib.create_tokenizer(
247+
self.config["tokenizer_config"],
248+
self.config["tokenizer_config"]["tokenizer_path"],
199249
)
200250

201-
tokenizer = grpo_trainer.rl_cluster.tokenizer
202251
if self.config.get("data_module", None):
203252
dataset = data_lib.get_dataset_from_module(
204253
self.config["data_module"],
@@ -216,6 +265,7 @@ def run_grpo_trainer(self):
216265
dataset=self.config["dataset_name"],
217266
tfds_download=self.config["tfds_download"],
218267
)
268+
self.compute_params(dataset)
219269
dataset, _ = data_lib.post_init_dataset(
220270
dataset,
221271
tokenizer,
@@ -225,6 +275,12 @@ def run_grpo_trainer(self):
225275
"max_prompt_length", None
226276
),
227277
)
278+
rl_cluster = self.create_rl_cluster(tokenizer)
279+
grpo_trainer = grpo_learner.GrpoLearner(
280+
rl_cluster=rl_cluster,
281+
reward_fns=self.obtain_reward_fn(),
282+
algo_config=GrpoConfig(**self.config["grpo_config"]),
283+
)
228284
grpo_trainer.train(dataset)
229285

230286

0 commit comments

Comments
 (0)