Skip to content
Open
16 changes: 14 additions & 2 deletions examples/rl/grpo/gsm8k/run_qwen3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@ set -x # Enable xtrace
model_name=${model_name:-"Qwen3-1.7B-base"}
batch_size=${batch_size:-1}
num_batches=${num_batches:-3738}
num_test_batches=${num_test_batches:-100}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-1.0}
train_split=${train_split:-"train"}
eval_split=${eval_split:-"test"}

echo "Using parameters:"
echo " Batch Size: $batch_size"
echo " Num Batches: $num_batches"
echo " Num Test Batches: $num_test_batches"
echo " Num Epochs: $num_train_epochs"
echo " Warmup Ratio: $warmup_ratio"
echo " Train Fraction: $train_fraction"
echo " Train Split: $train_split"
echo " Eval Split: $eval_split"

max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
max_steps=$(printf "%.0f" "$max_steps_float")
Expand All @@ -48,7 +56,7 @@ python3 -m tunix.cli.grpo_main \
model_config.rng_seed=42 \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
actor_model_config.mesh.shape="(2,4)" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.mesh.shape="(2,4)" \
Expand All @@ -59,8 +67,12 @@ python3 -m tunix.cli.grpo_main \
dataset_name="gsm8k" \
batch_size=$batch_size \
num_batches=$num_batches \
num_test_batches=$num_test_batches \
num_train_epochs=$num_train_epochs \
train_split=$train_split \
eval_data_source="tfds" \
eval_dataset_name="gsm8k" \
eval_num_batches=$num_test_batches \
eval_split=$eval_split \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
Expand Down
20 changes: 15 additions & 5 deletions examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@ set -x # Enable xtrace

# specify at cmd line to override defaults, e.g.
model_name=${model_name:-"Qwen3-1.7B-base"}
batch_size=${batch_size:-1}
num_batches=${num_batches:-3738}
batch_size=${batch_size:-16}
num_batches=${num_batches:-500}
num_test_batches=${num_test_batches:-10}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-1.0}
train_split=${train_split:-"train"}
eval_split=${eval_split:-"test"}

echo "Using parameters:"
echo " Batch Size: $batch_size"
echo " Num Batches: $num_batches"
echo " Num Test Batches: $num_test_batches"
echo " Num Epochs: $num_train_epochs"
echo " Warmup Ratio: $warmup_ratio"
echo " Train Fraction: $train_fraction"
echo " Train Split: $train_split"
echo " Eval Split: $eval_split"

max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
max_steps=$(printf "%.0f" "$max_steps_float")
Expand All @@ -48,7 +54,7 @@ python3 -m tunix.cli.grpo_main \
model_config.rng_seed=42 \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
actor_model_config.mesh.shape="(2,4)" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.mesh.shape="(2,4)" \
Expand All @@ -59,8 +65,12 @@ python3 -m tunix.cli.grpo_main \
dataset_name="gsm8k" \
batch_size=$batch_size \
num_batches=$num_batches \
num_test_batches=$num_test_batches \
num_train_epochs=$num_train_epochs \
train_split=$train_split \
eval_data_source="tfds" \
eval_dataset_name="gsm8k" \
eval_num_batches=$num_test_batches \
eval_split=$eval_split \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
Expand All @@ -73,7 +83,7 @@ python3 -m tunix.cli.grpo_main \
rl_training_config.actor_optimizer_config.b2=0.99 \
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
rl_training_config.eval_every_n_steps=10 \
rl_training_config.eval_every_n_steps=100 \
rl_training_config.max_steps=$max_steps \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
Expand Down
13 changes: 13 additions & 0 deletions tunix/cli/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,23 @@ max_target_length: 256
num_train_epochs: 1
train_fraction: 1.0
num_test_batches: 100
# Dataset split for training (e.g., "train", "validation")
train_split: "train"
# Controls the download flag only when using TFDS datasets. If false, the
# data_dir used will be set to `None` and chosen by default by tfds.
tfds_download: True

# Eval dataset configuration (optional)
eval_data_source: ""
eval_data_directory: ""
eval_data_module: ""
eval_dataset_name: ""
eval_num_batches: 100
# Dataset split for evaluation (e.g., "test", "validation")
eval_split: "test"
# Controls the download flag for eval TFDS datasets
eval_tfds_download: True

############################### Optimizer ###############################
# Optimizer config
optimizer_config: &base_optimizer_config
Expand Down
93 changes: 67 additions & 26 deletions tunix/cli/grpo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,55 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
)
return perf_config

def _load_and_init_dataset(self, tokenizer, data_config_prefix: str = "", split: str = "train"):
  """Loads a dataset from config and runs post-init processing on it.

  Resolution order: an explicit data module (`<prefix>data_module`) wins;
  otherwise a "local" `<prefix>data_source` loads from
  `<prefix>data_directory`; otherwise the dataset is fetched via TFDS using
  `<prefix>dataset_name`.

  Args:
    tokenizer: Tokenizer used for processing the dataset.
    data_config_prefix: Optional prefix for config keys (e.g., "eval_" for
      the eval dataset; "" selects the training-dataset keys).
    split: The dataset split to use (e.g., "train", "test"). Only applies to
      TFDS datasets; the module/local branches ignore it.

  Returns:
    The dataset returned by `data_lib.post_init_dataset`, batched with the
    shared (un-prefixed) `batch_size` and truncated to
    `<prefix>num_batches` when that key is set.
  """
  def get_key(suffix):
    # Maps a base config key to its prefixed variant, e.g. "eval_" +
    # "num_batches" -> "eval_num_batches".
    return f"{data_config_prefix}{suffix}"

  # Load dataset. NOTE(review): the "local" and module branches receive the
  # tokenizer, but the TFDS branch does not — confirm example_data handles
  # tokenization internally for TFDS sources.
  if self.config.get(get_key("data_module")):
    dataset = data_lib.get_dataset_from_module(
        self.config[get_key("data_module")],
        tokenizer,
    )
  # NOTE(review): direct [] indexing raises KeyError if the prefixed
  # data_source key is absent; callers are expected to only invoke this when
  # the (eval_)data_source or (eval_)data_module key is configured.
  elif self.config[get_key("data_source")] == "local":
    dataset = example_data.create_dataset(
        data_source=self.config[get_key("data_source")],
        dataset=self.config[get_key("data_directory")],
        tokenizer=tokenizer,
    )
  else:
    dataset = example_data.create_dataset(
        data_source="tfds",
        dataset=self.config[get_key("dataset_name")],
        # Defaults to False here even though base_config.yaml ships True;
        # the fallback only applies when the key is missing entirely.
        tfds_download=self.config.get(get_key("tfds_download"), False),
        split=split,
    )

  # Post-initialize dataset: batching, optional num_batches cap, and prompt
  # truncation to rollout_config.max_prompt_length when configured.
  dataset, _ = data_lib.post_init_dataset(
      dataset,
      tokenizer,
      batch_size=self.config["batch_size"],
      num_batches=self.config.get(get_key("num_batches"), None),
      max_prompt_length=self.config["rollout_config"].get(
          "max_prompt_length", None
      ),
  )

  return dataset

def create_rl_cluster(self):
# Should not use LoRA for reference model.
if self.config["reference_model_config"].get("lora_config"):
Expand Down Expand Up @@ -202,33 +251,25 @@ def run_grpo_trainer(self):
)

tokenizer = grpo_trainer.rl_cluster.tokenizer
if self.config.get("data_module", None):
dataset = data_lib.get_dataset_from_module(
self.config["data_module"],
tokenizer,
)
elif self.config["data_source"] == "local":
dataset = example_data.create_dataset(
data_source=self.config["data_source"],
dataset=self.config["data_directory"],
tokenizer=tokenizer,
)
else:
dataset = example_data.create_dataset(
data_source="tfds",
dataset=self.config["dataset_name"],
tfds_download=self.config["tfds_download"],

# Get dataset splits from config with defaults
train_split = self.config.get("train_split", "train")
eval_split = self.config.get("eval_split", "test")

# Load training dataset
dataset = self._load_and_init_dataset(tokenizer, split=train_split)

# Load eval dataset if configured
eval_dataset = None
if (self.config.get("eval_data_source", "") or
self.config.get("eval_data_module", "")):
eval_dataset = self._load_and_init_dataset(
tokenizer,
data_config_prefix="eval_",
split=eval_split
)
dataset, _ = data_lib.post_init_dataset(
dataset,
tokenizer,
batch_size=self.config["batch_size"],
num_batches=self.config.get("num_batches", None),
max_prompt_length=self.config["rollout_config"].get(
"max_prompt_length", None
),
)
grpo_trainer.train(dataset)

grpo_trainer.train(dataset, eval_ds=eval_dataset)


def _setup_jax_pathways(pathways_bns: str):
Expand Down
14 changes: 9 additions & 5 deletions tunix/models/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,18 @@ def download_model(
ValueError: If the model source is not supported for downloading.
"""

if model_source == ModelSource.KAGGLE:
if model_source in (ModelSource.KAGGLE, ModelSource.HUGGINGFACE):
from tunix.oss import utils as oss_utils # pylint: disable=g-import-not-at-top

return oss_utils.kaggle_pipeline(model_id_or_path, model_download_path)
elif model_source == ModelSource.HUGGINGFACE:
from tunix.oss import utils as oss_utils # pylint: disable=g-import-not-at-top
if model_download_path:
# Append the model name so different models don't share the same directory.
model_name = model_id_or_path.split('/')[-1]
model_download_path = os.path.join(model_download_path, model_name)

return oss_utils.hf_pipeline(model_id_or_path, model_download_path)
if model_source == ModelSource.KAGGLE:
return oss_utils.kaggle_pipeline(model_id_or_path, model_download_path)
else:
return oss_utils.hf_pipeline(model_id_or_path, model_download_path)
elif model_source == ModelSource.GCS:
return model_id_or_path
elif model_source == ModelSource.INTERNAL:
Expand Down
2 changes: 1 addition & 1 deletion tunix/rl/reward_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _log_one_example():
for k, v in log_metrics.items():
logging.info(f"{k}:\t{v[0][0]}")
logging.info("=======================")
_log_one_example()
_log_one_example()  # NOTE(review): do not disable via commented-out code; gate behind a verbosity/config flag if the per-example log is too noisy

return rewards_info

Expand Down
2 changes: 1 addition & 1 deletion tunix/rl/rl_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _log_metrics(self, metrics_buffer: MetricsBuffer) -> None:
continue

if agg_value.dtype.kind in {"U", "S"}:
logging.info(
logging.debug(
"Skipping logging metric %s (dtype: %s)",
metric_name,
agg_value.dtype,
Expand Down
Loading