Skip to content

Commit bf94583

Browse files
author
The tunix Authors
committed
Merge pull request #1257 from niting:niting/compute_steps
PiperOrigin-RevId: 889967078
2 parents a36f3fa + bdc940f commit bf94583

File tree

3 files changed

+79
-27
lines changed

3 files changed

+79
-27
lines changed

examples/rl/grpo/gsm8k/run_qwen3.sh

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,17 @@ set -x # Enable xtrace
1717

1818
# specify at cmd line to override defaults, e.g.
1919
model_name=${model_name:-"Qwen3-1.7B-base"}
20-
batch_size=${batch_size:-1}
21-
num_batches=${num_batches:-3738}
20+
batch_size=${batch_size:-8}
2221
num_train_epochs=${num_train_epochs:-1}
2322
warmup_ratio=${warmup_ratio:-0.1}
24-
train_fraction=${train_fraction:-1.0}
23+
train_fraction=${train_fraction:-0.8}
2524

2625
echo "Using parameters:"
2726
echo " Batch Size: $batch_size"
28-
echo " Num Batches: $num_batches"
2927
echo " Num Epochs: $num_train_epochs"
3028
echo " Warmup Ratio: $warmup_ratio"
3129
echo " Train Fraction: $train_fraction"
3230

33-
max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
34-
max_steps=$(printf "%.0f" "$max_steps_float")
35-
warmup_steps=$(awk "BEGIN {printf \"%.0f\", $warmup_ratio * $max_steps}")
36-
37-
echo "Max steps: $max_steps"
38-
echo "Rounded warmup steps: $warmup_steps"
39-
4031
python3 -m tunix.cli.grpo_main \
4132
base_config.yaml \
4233
model_config.model_name=${model_name} \
@@ -58,7 +49,6 @@ python3 -m tunix.cli.grpo_main \
5849
tokenizer_config.add_bos=false \
5950
dataset_name="gsm8k" \
6051
batch_size=$batch_size \
61-
num_batches=$num_batches \
6252
num_test_batches=100 \
6353
num_train_epochs=$num_train_epochs \
6454
rl_training_config.actor_optimizer_config.opt_type="adamw" \
@@ -67,14 +57,11 @@ python3 -m tunix.cli.grpo_main \
6757
rl_training_config.actor_optimizer_config.init_value=0.0 \
6858
rl_training_config.actor_optimizer_config.end_value=0.0 \
6959
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
70-
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
71-
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
7260
rl_training_config.actor_optimizer_config.b1=0.9 \
7361
rl_training_config.actor_optimizer_config.b2=0.99 \
7462
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
7563
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
7664
rl_training_config.eval_every_n_steps=10 \
77-
rl_training_config.max_steps=$max_steps \
7865
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
7966
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
8067
rl_training_config.checkpointing_options.save_interval_steps=500 \

tunix/cli/base_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ data_directory: ""
9494
data_module: ""
9595
dataset_name: "Helsinki-NLP/opus-100"
9696
batch_size: 16
97-
num_batches: 3738
9897
max_target_length: 256
9998
num_train_epochs: 1
10099
train_fraction: 1.0
@@ -210,4 +209,4 @@ offload_to_cpu: false
210209

211210
verl_compatible: false
212211
reward_functions:
213-
- tunix/cli/reward_fn/gsm8k.py
212+
- tunix/cli/reward_fn/gsm8k.py

tunix/cli/grpo_main.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tunix.rl import rl_cluster as rl_cluster_lib
3232
from tunix.rl.grpo import grpo_learner
3333
from tunix.rl.rollout import base_rollout
34+
from typing import Any
3435

3536
GrpoConfig = grpo_learner.GrpoConfig
3637

@@ -154,7 +155,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
154155
)
155156
return perf_config
156157

157-
def create_rl_cluster(self):
158+
def create_rl_cluster(self, tokenizer):
158159
# Should not use LoRA for reference model.
159160
if self.config["reference_model_config"].get("lora_config"):
160161
logging.warning(
@@ -180,10 +181,6 @@ def create_rl_cluster(self):
180181
jax.tree.map(jnp.copy, params),
181182
)
182183

183-
tokenizer = model_lib.create_tokenizer(
184-
self.config["tokenizer_config"], tokenizer_path
185-
)
186-
187184
cluster_config = self.create_cluster_config()
188185
perf_config = self.create_perf_config(cluster_config)
189186
return rl_cluster_lib.RLCluster(
@@ -194,14 +191,76 @@ def create_rl_cluster(self):
194191
perf_config=perf_config,
195192
)
196193

194+
def compute_params(self, dataset):
  """Derives step-count parameters from the dataset and stores them in config.

  Fills in ``rl_training_config.max_steps`` and, when an
  ``actor_optimizer_config`` is present, its ``decay_steps`` and
  ``warmup_steps`` — unless the caller already supplied them. The step
  budget is ``num_batches * num_train_epochs * train_fraction``, where
  ``num_batches`` defaults to ``len(dataset) // batch_size``.

  Args:
    dataset: The training dataset. Must implement ``__len__`` unless
      ``rl_training_config.max_steps`` is already set in the config.

  Raises:
    ValueError: If ``max_steps`` is unset and the dataset length cannot be
      determined, or if the supplied ``max_steps`` exceeds the number of
      steps the dataset can provide.
  """
  # NOTE(review): if "rl_training_config" is absent from the config, this
  # default dict is mutated but never stored back — assumed to always be
  # present (it is defined in base_config.yaml); confirm.
  rl_training_config: dict[str, Any] = self.config.get(
      "rl_training_config", {}
  )

  # A falsy (missing / 0 / None) max_steps means "compute it from the data".
  max_steps = rl_training_config.get("max_steps") or None

  dataset_length = None
  if hasattr(dataset, "__len__"):
    dataset_length = len(dataset)
  elif max_steps is None:
    raise ValueError(
        "max_steps must be specified since the dataset length cannot be"
        " determined."
    )

  if dataset_length is not None:
    # Guard against an explicit batch_size of 0, which would otherwise
    # raise ZeroDivisionError below.
    batch_size = self.config.get("batch_size", 1) or 1
    num_batches = self.config.get("num_batches")
    if not num_batches:
      num_batches = dataset_length // batch_size
      logging.info(
          "Dynamically computed num_batches=%d with batch_size=%d",
          num_batches,
          batch_size,
      )

    num_train_epochs = self.config.get("num_train_epochs") or 1

    train_fraction = self.config.get("train_fraction") or 0.8
    # Must be `or`: the previous `and` made this range check unreachable,
    # letting e.g. a negative train_fraction produce a negative step budget.
    if train_fraction <= 0.0 or train_fraction > 1.0:
      logging.warning(
          "train_fraction %.2f out of expected range. Setting to 0.8",
          train_fraction,
      )
      train_fraction = 0.8

    allowed_max_steps = int(num_batches * num_train_epochs * train_fraction)
    if max_steps is None:
      max_steps = allowed_max_steps
      logging.info(
          "Dynamically computed max_steps=%d based on dataset length %d",
          max_steps,
          dataset_length,
      )
    elif max_steps > allowed_max_steps:
      # f-string here: passing the value as a second ValueError argument
      # would leave the "%d" placeholder unformatted.
      raise ValueError(
          f"Maximum allowed value for max_steps is {allowed_max_steps}"
      )

  rl_training_config["max_steps"] = max_steps
  actor_opt: dict[str, Any] = rl_training_config.get(
      "actor_optimizer_config", {}
  )
  if actor_opt and not actor_opt.get("decay_steps"):
    actor_opt["decay_steps"] = max_steps
  if actor_opt and not actor_opt.get("warmup_steps"):
    warmup_ratio = self.config.get("warmup_ratio", 0.1)
    # A top-level warmup_steps in the config overrides the ratio-derived
    # value; the result may be fractional — TODO confirm the optimizer
    # schedule accepts a non-integer warmup_steps.
    actor_opt["warmup_steps"] = self.config.get(
        "warmup_steps", warmup_ratio * max_steps
    )
257+
197258
def run_grpo_trainer(self):
198-
grpo_trainer = grpo_learner.GrpoLearner(
199-
rl_cluster=self.create_rl_cluster(),
200-
reward_fns=self.obtain_reward_fn(),
201-
algo_config=GrpoConfig(**self.config["grpo_config"]),
259+
tokenizer = model_lib.create_tokenizer(
260+
self.config["tokenizer_config"],
261+
self.config["tokenizer_config"]["tokenizer_path"],
202262
)
203263

204-
tokenizer = grpo_trainer.rl_cluster.tokenizer
205264
if self.config.get("data_module", None):
206265
dataset = data_lib.get_dataset_from_module(
207266
self.config["data_module"],
@@ -219,6 +278,7 @@ def run_grpo_trainer(self):
219278
dataset=self.config["dataset_name"],
220279
tfds_download=self.config["tfds_download"],
221280
)
281+
self.compute_params(dataset)
222282
dataset, _ = data_lib.post_init_dataset(
223283
dataset,
224284
tokenizer,
@@ -228,6 +288,12 @@ def run_grpo_trainer(self):
228288
"max_prompt_length", None
229289
),
230290
)
291+
rl_cluster = self.create_rl_cluster(tokenizer)
292+
grpo_trainer = grpo_learner.GrpoLearner(
293+
rl_cluster=rl_cluster,
294+
reward_fns=self.obtain_reward_fn(),
295+
algo_config=GrpoConfig(**self.config["grpo_config"]),
296+
)
231297
grpo_trainer.train(dataset)
232298

233299

0 commit comments

Comments
 (0)