17 changes: 2 additions & 15 deletions examples/rl/grpo/gsm8k/run_qwen3.sh
@@ -17,26 +17,17 @@ set -x  # Enable xtrace

 # specify at cmd line to override defaults, e.g.
 model_name=${model_name:-"Qwen3-1.7B-base"}
-batch_size=${batch_size:-1}
-num_batches=${num_batches:-3738}
+batch_size=${batch_size:-8}
 num_train_epochs=${num_train_epochs:-1}
 warmup_ratio=${warmup_ratio:-0.1}
-train_fraction=${train_fraction:-1.0}
+train_fraction=${train_fraction:-0.8}

 echo "Using parameters:"
 echo "  Batch Size: $batch_size"
-echo "  Num Batches: $num_batches"
 echo "  Num Epochs: $num_train_epochs"
 echo "  Warmup Ratio: $warmup_ratio"
 echo "  Train Fraction: $train_fraction"

-max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
-max_steps=$(printf "%.0f" "$max_steps_float")
-warmup_steps=$(awk "BEGIN {printf \"%.0f\", $warmup_ratio * $max_steps}")
-
-echo "Max steps: $max_steps"
-echo "Rounded warmup steps: $warmup_steps"
-
 python3 -m tunix.cli.grpo_main \
     base_config.yaml \
     model_config.model_name=${model_name} \
@@ -58,7 +49,6 @@ python3 -m tunix.cli.grpo_main \
     tokenizer_config.add_bos=false \
     dataset_name="gsm8k" \
     batch_size=$batch_size \
-    num_batches=$num_batches \
     num_test_batches=100 \
     num_train_epochs=$num_train_epochs \
     rl_training_config.actor_optimizer_config.opt_type="adamw" \
@@ -67,14 +57,11 @@ python3 -m tunix.cli.grpo_main \
     rl_training_config.actor_optimizer_config.init_value=0.0 \
     rl_training_config.actor_optimizer_config.end_value=0.0 \
     rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
Review comment (Collaborator):
Maybe we should still set warmup_ratio to 0.1 explicitly instead of relying on the default value?

Reply (Contributor, author):
Done. Reverted this change.
-    rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
-    rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
     rl_training_config.actor_optimizer_config.b1=0.9 \
     rl_training_config.actor_optimizer_config.b2=0.99 \
     rl_training_config.actor_optimizer_config.weight_decay=0.1 \
     rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
     rl_training_config.eval_every_n_steps=10 \
-    rl_training_config.max_steps=$max_steps \
     rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
     rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
     rl_training_config.checkpointing_options.save_interval_steps=500 \
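The awk arithmetic deleted above does not move verbatim into Python: the old shell pipeline multiplied batch_size into max_steps, while the replacement logic in tunix/cli/grpo_main.py (see the compute_params diff below) derives num_batches from the dataset length and drops batch_size from the product. A minimal sketch of the before/after formulas; the GSM8K train-split size of 7473 is an assumption for illustration, not stated in this diff:

```python
# Illustrative comparison only; the concrete numbers are assumptions.
batch_size, num_train_epochs, train_fraction = 8, 1, 0.8

# Old shell logic: hand-maintained num_batches, with batch_size in the product.
num_batches_old = 3738
max_steps_old = round(
    batch_size * num_batches_old * num_train_epochs * train_fraction
)  # 23923

# New compute_params logic: num_batches derived from the dataset length.
dataset_length = 7473  # assumed GSM8K train size
num_batches = dataset_length // batch_size  # 934
max_steps = int(num_batches * num_train_epochs * train_fraction)  # 747
```

In other words, max_steps now appears to count optimizer steps over batches rather than scaling with the number of examples per batch.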
3 changes: 1 addition & 2 deletions tunix/cli/base_config.yaml
@@ -94,7 +94,6 @@ data_directory: ""
 data_module: ""
 dataset_name: "Helsinki-NLP/opus-100"
 batch_size: 16
-num_batches: 3738
 max_target_length: 256
 num_train_epochs: 1
 train_fraction: 1.0
@@ -210,4 +209,4 @@ offload_to_cpu: false

 verl_compatible: false
 reward_functions:
-  - tunix/cli/reward_fn/gsm8k.py
+  - tunix/cli/reward_fn/gsm8k.py
76 changes: 66 additions & 10 deletions tunix/cli/grpo_main.py
@@ -151,7 +151,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
     ).export_metrics
     return perf_config

-  def create_rl_cluster(self):
+  def create_rl_cluster(self, tokenizer):
     # Should not use LoRA for reference model.
     if self.config["reference_model_config"].get("lora_config"):
       logging.warning(
@@ -177,10 +177,6 @@ def create_rl_cluster(self):
         jax.tree.map(jnp.copy, params),
     )

-    tokenizer = model_lib.create_tokenizer(
-        self.config["tokenizer_config"], tokenizer_path
-    )
-
     cluster_config = self.create_cluster_config()
     perf_config = self.create_perf_config(cluster_config)
     return rl_cluster_lib.RLCluster(
@@ -191,14 +187,67 @@ def create_rl_cluster(self):
         perf_config=perf_config,
     )

+  def compute_params(self, dataset):
+    rl_training_config = self.config.get("rl_training_config", {})
+
+    # Use max_steps directly if it is already specified.
+    max_steps = None
+    if rl_training_config.get("max_steps"):
+      max_steps = rl_training_config.get("max_steps")
+    elif not hasattr(dataset, "__len__"):
+      raise ValueError(
+          "max_steps must be specified since the dataset length cannot be"
+          " determined."
+      )
+
+    dataset_length = len(dataset)
+
+    batch_size = self.config.get("batch_size", 1)
+    num_batches = self.config.get("num_batches")
+    if not num_batches:
+      num_batches = dataset_length // batch_size
+      logging.info(
+          "Dynamically computed num_batches=%d with batch_size=%d",
+          num_batches,
+          batch_size,
+      )
+    num_train_epochs = self.config.get("num_train_epochs")
+    if not num_train_epochs:
+      num_train_epochs = 1
+
+    train_fraction = self.config.get("train_fraction")
+    if not train_fraction:
+      train_fraction = 0.8
+    elif train_fraction <= 0.0 or train_fraction > 1.0:
+      logging.warning(
+          f"train_fraction {train_fraction:.2f} out of expected range. Setting"
+          " to 0.8"
+      )
+      train_fraction = 0.8
+
+    if not max_steps:
Review comment (Collaborator):
Shall we check max_steps against int(num_batches * num_train_epochs * train_fraction) if max_steps is available?

Reply (Contributor, author):
I can, but I was assuming that the user might specify max_steps when they want to try out the behavior with a different number of steps. I could cap max_steps at that value or leave it as is for now. What do you prefer?
+      max_steps = int(num_batches * num_train_epochs * train_fraction)
+
+    rl_training_config["max_steps"] = max_steps
+    actor_opt = rl_training_config.get("actor_optimizer_config", {})
+    if actor_opt and not actor_opt.get("decay_steps"):
+      actor_opt["decay_steps"] = max_steps
+    if actor_opt and not actor_opt.get("warmup_steps"):
+      warmup_ratio = self.config.get("warmup_ratio", 0.1)
+      warmup_steps = self.config.get("warmup_steps", warmup_ratio * max_steps)
+      actor_opt["warmup_steps"] = warmup_steps
+    logging.info(
+        "Dynamically computed max_steps=%d based on dataset length %d",
+        max_steps,
+        dataset_length,
+    )

   def run_grpo_trainer(self):
-    grpo_trainer = grpo_learner.GrpoLearner(
-        rl_cluster=self.create_rl_cluster(),
-        reward_fns=self.obtain_reward_fn(),
-        algo_config=GrpoConfig(**self.config["grpo_config"]),
+    tokenizer = model_lib.create_tokenizer(
+        self.config["tokenizer_config"],
+        self.config["tokenizer_config"]["tokenizer_path"],
     )

-    tokenizer = grpo_trainer.rl_cluster.tokenizer
     if self.config.get("data_module", None):
       dataset = data_lib.get_dataset_from_module(
           self.config["data_module"],
@@ -216,6 +265,7 @@ def run_grpo_trainer(self):
           dataset=self.config["dataset_name"],
           tfds_download=self.config["tfds_download"],
       )
+    self.compute_params(dataset)
     dataset, _ = data_lib.post_init_dataset(
         dataset,
         tokenizer,
@@ -225,6 +275,12 @@ def run_grpo_trainer(self):
"max_prompt_length", None
),
)
rl_cluster = self.create_rl_cluster(tokenizer)
grpo_trainer = grpo_learner.GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=self.obtain_reward_fn(),
algo_config=GrpoConfig(**self.config["grpo_config"]),
)
grpo_trainer.train(dataset)


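Taken together, compute_params now back-fills the max_steps, decay_steps, and warmup_steps values that run_qwen3.sh used to compute with awk. A standalone sketch of the same derivation under the defaults visible in this diff (train_fraction 0.8, warmup_ratio 0.1); the function name and sample numbers are illustrative, not part of the tunix API:

```python
# Hypothetical re-implementation of the arithmetic in compute_params,
# using this PR's defaults; not the tunix API.
def derive_schedule(
    dataset_length: int,
    batch_size: int = 8,
    num_train_epochs: int = 1,
    train_fraction: float = 0.8,
    warmup_ratio: float = 0.1,
):
  num_batches = dataset_length // batch_size
  max_steps = int(num_batches * num_train_epochs * train_fraction)
  warmup_steps = warmup_ratio * max_steps  # left unrounded, see note below
  decay_steps = max_steps
  return max_steps, warmup_steps, decay_steps

# A 7473-example dataset at batch size 8:
print(derive_schedule(7473))  # (747, 74.7, 747)
```

One detail worth flagging: the deleted shell logic rounded warmup_steps with printf "%.0f", whereas compute_params passes warmup_ratio * max_steps through unrounded, so warmup_steps may now be a float unless the optimizer config coerces it to an integer.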