10 changes: 9 additions & 1 deletion examples/rl/grpo/gsm8k/run_qwen3.sh
@@ -27,19 +27,22 @@ echo " Batch Size: $batch_size"
 echo " Num Epochs: $num_train_epochs"
 echo " Warmup Ratio: $warmup_ratio"
 echo " Train Fraction: $train_fraction"
+echo " Train Split: $train_split"
+echo " Eval Split: $eval_split"

 python3 -m tunix.cli.grpo_main \
 base_config.yaml \
 model_config.model_name=${model_name} \
 model_config.model_id=Qwen/${model_name} \
 model_config.model_source=huggingface \
+model_config.model_download_path="/tmp/models/${model_name}" \
 model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
 model_config.mesh.shape="(2,4)" \
 model_config.mesh.axis_names="('fsdp','tp')" \
 model_config.rng_seed=42 \
 actor_model_config.lora_config.rank=64 \
 actor_model_config.lora_config.alpha=64.0 \
-actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
+actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
 actor_model_config.mesh.shape="(2,4)" \
 actor_model_config.mesh.axis_names="('fsdp','tp')" \
 rollout_model_config.mesh.shape="(2,4)" \
@@ -51,6 +54,11 @@ python3 -m tunix.cli.grpo_main \
 batch_size=$batch_size \
 num_test_batches=100 \
 num_train_epochs=$num_train_epochs \
+train_split=$train_split \
+eval_data_source="tfds" \
+eval_dataset_name="gsm8k" \
+eval_num_batches=$num_test_batches \
+eval_split=$eval_split \
 rl_training_config.actor_optimizer_config.opt_type="adamw" \
 rl_training_config.actor_optimizer_config.peak_value=3e-6 \
 rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
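The `module_path` change above swaps the Gemma-style einsum module names for the HF-style projection names that Qwen models use. A quick sanity check of the new regex; the module names below are illustrative, not taken from a real checkpoint:

import re

# The new LoRA target pattern from the script above.
pattern = re.compile(
    r".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
)
assert pattern.match("model.layers.0.self_attn.q_proj")
assert pattern.match("model.layers.0.mlp.up_proj")
assert not pattern.match("layer_0.attn_vec_einsum")  # old-style name, no longer targeted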
21 changes: 16 additions & 5 deletions examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
100644 → 100755
@@ -17,18 +17,24 @@ set -x # Enable xtrace

 # specify at cmd line to override defaults, e.g.
 model_name=${model_name:-"Qwen3-1.7B-base"}
-batch_size=${batch_size:-1}
-num_batches=${num_batches:-3738}
+batch_size=${batch_size:-16}
+num_batches=${num_batches:-500}
+num_test_batches=${num_test_batches:-10}
 num_train_epochs=${num_train_epochs:-1}
 warmup_ratio=${warmup_ratio:-0.1}
 train_fraction=${train_fraction:-1.0}
+train_split=${train_split:-"train"}
+eval_split=${eval_split:-"test"}

 echo "Using parameters:"
 echo " Batch Size: $batch_size"
 echo " Num Batches: $num_batches"
+echo " Num Test Batches: $num_test_batches"
 echo " Num Epochs: $num_train_epochs"
 echo " Warmup Ratio: $warmup_ratio"
 echo " Train Fraction: $train_fraction"
+echo " Train Split: $train_split"
+echo " Eval Split: $eval_split"

 max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
 max_steps=$(printf "%.0f" "$max_steps_float")
@@ -42,13 +48,14 @@ python3 -m tunix.cli.grpo_main \
 model_config.model_name=${model_name} \
 model_config.model_id=Qwen/${model_name} \
 model_config.model_source=huggingface \
+model_config.model_download_path="/tmp/models/${model_name}" \
 model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
 model_config.mesh.shape="(2,4)" \
 model_config.mesh.axis_names="('fsdp','tp')" \
 model_config.rng_seed=42 \
 actor_model_config.lora_config.rank=64 \
 actor_model_config.lora_config.alpha=64.0 \
-actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
+actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
 actor_model_config.mesh.shape="(2,4)" \
 actor_model_config.mesh.axis_names="('fsdp','tp')" \
 rollout_model_config.mesh.shape="(2,4)" \
@@ -59,8 +66,12 @@
 dataset_name="gsm8k" \
 batch_size=$batch_size \
 num_batches=$num_batches \
-num_test_batches=100 \
 num_train_epochs=$num_train_epochs \
+train_split=$train_split \
+eval_data_source="tfds" \
+eval_dataset_name="gsm8k" \
+eval_num_batches=$num_test_batches \
+eval_split=$eval_split \
 rl_training_config.actor_optimizer_config.opt_type="adamw" \
 rl_training_config.actor_optimizer_config.peak_value=3e-6 \
 rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
@@ -73,7 +84,7 @@ python3 -m tunix.cli.grpo_main \
 rl_training_config.actor_optimizer_config.b2=0.99 \
 rl_training_config.actor_optimizer_config.weight_decay=0.1 \
 rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
-rl_training_config.eval_every_n_steps=10 \
+rl_training_config.eval_every_n_steps=100 \
 rl_training_config.max_steps=$max_steps \
 rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
 rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
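For reference, the `max_steps` arithmetic above with the new defaults (`batch_size=16`, `num_batches=500`, `num_train_epochs=1`, `train_fraction=1.0`) resolves to 8000. A minimal Python sketch of the same computation the awk expression performs:

# Python equivalent of the awk expression in the script above.
batch_size, num_batches, num_train_epochs, train_fraction = 16, 500, 1, 1.0
max_steps = round(batch_size * num_batches * num_train_epochs * train_fraction)
assert max_steps == 8000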
13 changes: 13 additions & 0 deletions tunix/cli/base_config.yaml
@@ -98,10 +98,23 @@ max_target_length: 256
 num_train_epochs: 1
 train_fraction: 1.0
 num_test_batches: 100
+# Dataset split for training (e.g., "train", "validation")
+train_split: "train"
 # Controls the download flag only when using TFDS datasets. If false, the
 # data_dir used will be set to `None` and chosen by default by tfds.
 tfds_download: True
+
+# Eval dataset configuration (optional)
+eval_data_source: ""
+eval_data_directory: ""
+eval_data_module: ""
+eval_dataset_name: ""
+eval_num_batches: 100
+# Dataset split for evaluation (e.g., "test", "validation")
+eval_split: "test"
+# Controls the download flag for eval TFDS datasets
+eval_tfds_download: True

 ############################### Optimizer ###############################
 # Optimizer config
 optimizer_config: &base_optimizer_config
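Note that the `eval_*` block is optional: `grpo_main.py` only builds an eval dataset when `eval_data_source` or `eval_data_module` is non-empty. A minimal sketch of that gate, with illustrative values:

# Eval loading is gated on a non-empty eval_data_source or eval_data_module.
config = {"eval_data_source": "tfds", "eval_data_module": ""}
should_load_eval = bool(
    config.get("eval_data_source", "") or config.get("eval_data_module", "")
)
assert should_load_eval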
94 changes: 68 additions & 26 deletions tunix/cli/grpo_main.py
@@ -21,6 +21,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
+from typing import Any
 from tunix.cli import config
 from tunix.cli.utils import data as data_lib
 from tunix.cli.utils import model as model_lib
@@ -155,6 +156,55 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
     )
     return perf_config

+  def _load_and_init_dataset(self, tokenizer: Any, *, data_config_prefix: str = "", split: str = "train"):
+    """Helper function to load and initialize a dataset.
+
+    Args:
+      tokenizer: Tokenizer for processing.
+      data_config_prefix: Optional prefix for config keys (e.g., "eval_" for
+        the eval dataset).
+      split: The dataset split to use (e.g., "train", "test"). Only applies to
+        TFDS datasets.
+
+    Returns:
+      Initialized dataset ready for training or evaluation.
+    """
+    def get_key(suffix):
+      return f"{data_config_prefix}{suffix}"
+
+    # Load dataset
+    if self.config.get(get_key("data_module")):
+      dataset = data_lib.get_dataset_from_module(
+          self.config[get_key("data_module")],
+          tokenizer,
+      )
+    elif self.config[get_key("data_source")] == "local":
+      dataset = example_data.create_dataset(
+          data_source=self.config[get_key("data_source")],
+          dataset=self.config[get_key("data_directory")],
+          tokenizer=tokenizer,
+      )
+    else:
+      dataset = example_data.create_dataset(
+          data_source="tfds",
+          dataset=self.config[get_key("dataset_name")],
+          tfds_download=self.config.get(get_key("tfds_download"), False),
+          split=split,
+      )
+
+    # Post-initialize dataset
+    dataset, _ = data_lib.post_init_dataset(
+        dataset,
+        tokenizer,
+        batch_size=self.config["batch_size"],
+        num_batches=self.config.get(get_key("num_batches"), None),
+        max_prompt_length=self.config["rollout_config"].get(
+            "max_prompt_length", None
+        ),
+    )
+
+    return dataset
+
   def create_rl_cluster(self, tokenizer):
     # Should not use LoRA for reference model.
     if self.config["reference_model_config"].get("lora_config"):
@@ -260,41 +310,33 @@ def run_grpo_trainer(self):
self.config["tokenizer_config"],
self.config["tokenizer_config"]["tokenizer_path"],
)
tokenizer = grpo_trainer.rl_cluster.tokenizer
# Get dataset splits from config with defaults
train_split = self.config.get("train_split", "train")
eval_split = self.config.get("eval_split", "test")

if self.config.get("data_module", None):
dataset = data_lib.get_dataset_from_module(
self.config["data_module"],
tokenizer,
)
elif self.config["data_source"] == "local":
dataset = example_data.create_dataset(
data_source=self.config["data_source"],
dataset=self.config["data_directory"],
tokenizer=tokenizer,
# Load training dataset
dataset = self._load_and_init_dataset(tokenizer, split=train_split)
self.compute_params(dataset)

# Load eval dataset if configured
if (self.config.get("eval_data_source", "") or
self.config.get("eval_data_module", "")):
eval_dataset = self._load_and_init_dataset(
tokenizer,
data_config_prefix="eval_",
split=eval_split
)
else:
dataset = example_data.create_dataset(
data_source="tfds",
dataset=self.config["dataset_name"],
tfds_download=self.config["tfds_download"],
)
self.compute_params(dataset)
dataset, _ = data_lib.post_init_dataset(
dataset,
tokenizer,
batch_size=self.config["batch_size"],
num_batches=self.config.get("num_batches", None),
max_prompt_length=self.config["rollout_config"].get(
"max_prompt_length", None
),
)
eval_dataset = None

rl_cluster = self.create_rl_cluster(tokenizer)
grpo_trainer = grpo_learner.GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=self.obtain_reward_fn(),
algo_config=GrpoConfig(**self.config["grpo_config"]),
)
grpo_trainer.train(dataset)
grpo_trainer.train(dataset, eval_ds=eval_dataset)


def _setup_jax_pathways(pathways_bns: str):
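The `_load_and_init_dataset` helper above keys everything off an optional config prefix, so the same code path serves both the train and eval datasets. A standalone sketch of the lookup pattern, using a hypothetical config dict:

# Hypothetical config: eval_* keys mirror their train-side counterparts.
config = {
    "data_source": "tfds",
    "dataset_name": "gsm8k",
    "eval_data_source": "tfds",
    "eval_dataset_name": "gsm8k",
}

def get_key(prefix: str, suffix: str) -> str:
  return f"{prefix}{suffix}"

assert config[get_key("", "dataset_name")] == "gsm8k"
assert config[get_key("eval_", "data_source")] == "tfds"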
2 changes: 1 addition & 1 deletion tunix/rl/rl_cluster.py
@@ -683,7 +683,7 @@ def _log_metrics(self, metrics_buffer: MetricsBuffer) -> None:
         continue

       if agg_value.dtype.kind in {"U", "S"}:
-        logging.info(
+        logging.debug(
             "Skipping logging metric %s (dtype: %s)",
             metric_name,
             agg_value.dtype,
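The demoted log line fires for string-valued metrics: NumPy reports dtype kind "U" for unicode strings and "S" for byte strings, which is exactly what the guard above checks. A quick illustration:

import numpy as np

# dtype.kind values relevant to the guard above.
assert np.array(["a"]).dtype.kind == "U"   # unicode strings are skipped
assert np.array([b"a"]).dtype.kind == "S"  # byte strings are skipped
assert np.array([1.0]).dtype.kind == "f"   # numeric metrics are still logged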