diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh
index 7b4daa849..1aac0f38d 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3.sh
@@ -27,19 +27,22 @@ echo " Batch Size: $batch_size"
 echo " Num Epochs: $num_train_epochs"
 echo " Warmup Ratio: $warmup_ratio"
 echo " Train Fraction: $train_fraction"
+echo " Train Split: $train_split"
+echo " Eval Split: $eval_split"
 
 python3 -m tunix.cli.grpo_main \
     base_config.yaml \
     model_config.model_name=${model_name} \
     model_config.model_id=Qwen/${model_name} \
     model_config.model_source=huggingface \
+    model_config.model_download_path="/tmp/models/${model_name}" \
     model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
     model_config.mesh.shape="(2,4)" \
     model_config.mesh.axis_names="('fsdp','tp')" \
     model_config.rng_seed=42 \
     actor_model_config.lora_config.rank=64 \
     actor_model_config.lora_config.alpha=64.0 \
-    actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
+    actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
     actor_model_config.mesh.shape="(2,4)" \
     actor_model_config.mesh.axis_names="('fsdp','tp')" \
     rollout_model_config.mesh.shape="(2,4)" \
@@ -51,6 +54,11 @@ python3 -m tunix.cli.grpo_main \
     batch_size=$batch_size \
     num_test_batches=100 \
     num_train_epochs=$num_train_epochs \
+    train_split=$train_split \
+    eval_data_source="tfds" \
+    eval_dataset_name="gsm8k" \
+    eval_num_batches=$num_test_batches \
+    eval_split=$eval_split \
     rl_training_config.actor_optimizer_config.opt_type="adamw" \
     rl_training_config.actor_optimizer_config.peak_value=3e-6 \
     rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
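Note: the `module_path` change above retargets LoRA from Gemma-style einsum modules to the Qwen-style attention and MLP projections. A minimal sketch of how an alternation regex like this selects modules by name (the module paths below are illustrative, not read from a real checkpoint):

```python
import re

# The same alternation used in the script above.
PATTERN = re.compile(
    r".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
)

# Hypothetical module paths for one Qwen-style decoder layer.
module_paths = [
    "layers.0.self_attn.q_proj",
    "layers.0.self_attn.k_proj",
    "layers.0.self_attn.v_proj",
    "layers.0.self_attn.o_proj",
    "layers.0.mlp.gate_proj",
    "layers.0.mlp.up_proj",
    "layers.0.mlp.down_proj",
    "layers.0.input_layernorm",  # no match: norms are left out of LoRA
]

selected = [p for p in module_paths if PATTERN.fullmatch(p)]
print(selected)  # all seven projections; the layernorm is excluded
```

Whether Tunix matches with `fullmatch` or `search` is an implementation detail of its LoRA config; the sketch only illustrates which module names the pattern is meant to capture.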
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
old mode 100644
new mode 100755
index 1a84f156a..926eda833
--- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
@@ -17,18 +17,24 @@ set -x # Enable xtrace
 
 # specify at cmd line to override defaults, e.g.
 model_name=${model_name:-"Qwen3-1.7B-base"}
-batch_size=${batch_size:-1}
-num_batches=${num_batches:-3738}
+batch_size=${batch_size:-16}
+num_batches=${num_batches:-500}
+num_test_batches=${num_test_batches:-10}
 num_train_epochs=${num_train_epochs:-1}
 warmup_ratio=${warmup_ratio:-0.1}
 train_fraction=${train_fraction:-1.0}
+train_split=${train_split:-"train"}
+eval_split=${eval_split:-"test"}
 
 echo "Using parameters:"
 echo " Batch Size: $batch_size"
 echo " Num Batches: $num_batches"
+echo " Num Test Batches: $num_test_batches"
 echo " Num Epochs: $num_train_epochs"
 echo " Warmup Ratio: $warmup_ratio"
 echo " Train Fraction: $train_fraction"
+echo " Train Split: $train_split"
+echo " Eval Split: $eval_split"
 
 max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
 max_steps=$(printf "%.0f" "$max_steps_float")
@@ -42,13 +48,14 @@ python3 -m tunix.cli.grpo_main \
     model_config.model_name=${model_name} \
     model_config.model_id=Qwen/${model_name} \
     model_config.model_source=huggingface \
+    model_config.model_download_path="/tmp/models/${model_name}" \
     model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
     model_config.mesh.shape="(2,4)" \
     model_config.mesh.axis_names="('fsdp','tp')" \
     model_config.rng_seed=42 \
     actor_model_config.lora_config.rank=64 \
     actor_model_config.lora_config.alpha=64.0 \
-    actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
+    actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
     actor_model_config.mesh.shape="(2,4)" \
     actor_model_config.mesh.axis_names="('fsdp','tp')" \
     rollout_model_config.mesh.shape="(2,4)" \
@@ -59,8 +66,12 @@ python3 -m tunix.cli.grpo_main \
     dataset_name="gsm8k" \
     batch_size=$batch_size \
     num_batches=$num_batches \
-    num_test_batches=100 \
     num_train_epochs=$num_train_epochs \
+    train_split=$train_split \
+    eval_data_source="tfds" \
+    eval_dataset_name="gsm8k" \
+    eval_num_batches=$num_test_batches \
+    eval_split=$eval_split \
     rl_training_config.actor_optimizer_config.opt_type="adamw" \
     rl_training_config.actor_optimizer_config.peak_value=3e-6 \
     rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
@@ -73,7 +84,7 @@ python3 -m tunix.cli.grpo_main \
     rl_training_config.actor_optimizer_config.b1=0.9 \
     rl_training_config.actor_optimizer_config.b2=0.99 \
     rl_training_config.actor_optimizer_config.weight_decay=0.1 \
     rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
-    rl_training_config.eval_every_n_steps=10 \
+    rl_training_config.eval_every_n_steps=100 \
     rl_training_config.max_steps=$max_steps \
     rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
     rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
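Note: both scripts derive `max_steps` from the same `awk` expression, so the new defaults change the effective step budget. A quick Python rendering of that arithmetic with the new defaults (`batch_size=16`, `num_batches=500`):

```python
# Mirrors the shell lines:
#   max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}")
#   max_steps=$(printf "%.0f" "$max_steps_float")
batch_size = 16
num_batches = 500
num_train_epochs = 1
train_fraction = 1.0

# printf "%.0f" rounds to the nearest integer; round() does the same here.
max_steps = round(batch_size * num_batches * num_train_epochs * train_fraction)
print(max_steps)  # 8000
```

With `rl_training_config.eval_every_n_steps=100`, that budget amounts to roughly 80 evaluation passes per run, which is why the interval was raised from 10.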
diff --git a/tunix/cli/base_config.yaml b/tunix/cli/base_config.yaml
index 789343280..c4810e438 100644
--- a/tunix/cli/base_config.yaml
+++ b/tunix/cli/base_config.yaml
@@ -98,10 +98,23 @@ max_target_length: 256
 num_train_epochs: 1
 train_fraction: 1.0
 num_test_batches: 100
+# Dataset split for training (e.g., "train", "validation")
+train_split: "train"
 # Controls the download flag only when using TFDS datasets. If false, the
 # data_dir used will be set to `None` and chosen by default by tfds.
 tfds_download: True
+# Eval dataset configuration (optional)
+eval_data_source: ""
+eval_data_directory: ""
+eval_data_module: ""
+eval_dataset_name: ""
+eval_num_batches: 100
+# Dataset split for evaluation (e.g., "test", "validation")
+eval_split: "test"
+# Controls the download flag for eval TFDS datasets
+eval_tfds_download: True
+
 
 ############################### Optimizer ###############################
 # Optimizer config
 optimizer_config: &base_optimizer_config
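Note: each new `eval_*` key mirrors its training counterpart, which is what lets the `_load_and_init_dataset` helper introduced below resolve either set of keys from a single prefix. A toy sketch of that lookup (the config dict here is illustrative):

```python
# Illustrative config mirroring the YAML defaults above.
config = {
    "data_source": "tfds",
    "dataset_name": "gsm8k",
    "num_batches": 500,
    "eval_data_source": "tfds",
    "eval_dataset_name": "gsm8k",
    "eval_num_batches": 100,
}

def get_key(prefix: str, suffix: str) -> str:
  # "" resolves the training keys; "eval_" resolves the eval keys.
  return f"{prefix}{suffix}"

for prefix in ("", "eval_"):
  print(
      config[get_key(prefix, "data_source")],
      config[get_key(prefix, "dataset_name")],
      config.get(get_key(prefix, "num_batches")),
  )
# tfds gsm8k 500
# tfds gsm8k 100
```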
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 440247897..1cb3b775a 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -21,6 +21,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
+from typing import Any
 from tunix.cli import config
 from tunix.cli.utils import data as data_lib
 from tunix.cli.utils import model as model_lib
@@ -155,6 +156,55 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
     )
     return perf_config
 
+  def _load_and_init_dataset(self, tokenizer: Any, *, data_config_prefix: str = "", split: str = "train"):
+    """Loads and initializes a dataset.
+
+    Args:
+      tokenizer: Tokenizer for processing.
+      data_config_prefix: Optional prefix for config keys (e.g., "eval_" for
+        the eval dataset).
+      split: The dataset split to use (e.g., "train", "test"). Only applies
+        to TFDS datasets.
+
+    Returns:
+      Initialized dataset ready for training or evaluation.
+    """
+    def get_key(suffix):
+      return f"{data_config_prefix}{suffix}"
+
+    # Load dataset
+    if self.config.get(get_key("data_module")):
+      dataset = data_lib.get_dataset_from_module(
+          self.config[get_key("data_module")],
+          tokenizer,
+      )
+    elif self.config[get_key("data_source")] == "local":
+      dataset = example_data.create_dataset(
+          data_source=self.config[get_key("data_source")],
+          dataset=self.config[get_key("data_directory")],
+          tokenizer=tokenizer,
+      )
+    else:
+      dataset = example_data.create_dataset(
+          data_source="tfds",
+          dataset=self.config[get_key("dataset_name")],
+          tfds_download=self.config.get(get_key("tfds_download"), False),
+          split=split,
+      )
+
+    # Post-initialize dataset
+    dataset, _ = data_lib.post_init_dataset(
+        dataset,
+        tokenizer,
+        batch_size=self.config["batch_size"],
+        num_batches=self.config.get(get_key("num_batches"), None),
+        max_prompt_length=self.config["rollout_config"].get(
+            "max_prompt_length", None
+        ),
+    )
+
+    return dataset
+
   def create_rl_cluster(self, tokenizer):
     # Should not use LoRA for reference model.
     if self.config["reference_model_config"].get("lora_config"):
@@ -260,41 +310,32 @@ def run_grpo_trainer(self):
         self.config["tokenizer_config"],
         self.config["tokenizer_config"]["tokenizer_path"],
     )
+    # Get dataset splits from config with defaults
+    train_split = self.config.get("train_split", "train")
+    eval_split = self.config.get("eval_split", "test")
 
-    if self.config.get("data_module", None):
-      dataset = data_lib.get_dataset_from_module(
-          self.config["data_module"],
-          tokenizer,
-      )
-    elif self.config["data_source"] == "local":
-      dataset = example_data.create_dataset(
-          data_source=self.config["data_source"],
-          dataset=self.config["data_directory"],
-          tokenizer=tokenizer,
+    # Load training dataset
+    dataset = self._load_and_init_dataset(tokenizer, split=train_split)
+    self.compute_params(dataset)
+
+    # Load eval dataset if configured
+    if (self.config.get("eval_data_source", "") or
+        self.config.get("eval_data_module", "")):
+      eval_dataset = self._load_and_init_dataset(
+          tokenizer,
+          data_config_prefix="eval_",
+          split=eval_split
       )
     else:
-      dataset = example_data.create_dataset(
-          data_source="tfds",
-          dataset=self.config["dataset_name"],
-          tfds_download=self.config["tfds_download"],
-      )
-    self.compute_params(dataset)
-    dataset, _ = data_lib.post_init_dataset(
-        dataset,
-        tokenizer,
-        batch_size=self.config["batch_size"],
-        num_batches=self.config.get("num_batches", None),
-        max_prompt_length=self.config["rollout_config"].get(
-            "max_prompt_length", None
-        ),
-    )
+      eval_dataset = None
+
     rl_cluster = self.create_rl_cluster(tokenizer)
     grpo_trainer = grpo_learner.GrpoLearner(
         rl_cluster=rl_cluster,
         reward_fns=self.obtain_reward_fn(),
         algo_config=GrpoConfig(**self.config["grpo_config"]),
     )
-    grpo_trainer.train(dataset)
+    grpo_trainer.train(dataset, eval_ds=eval_dataset)
 
 
 def _setup_jax_pathways(pathways_bns: str):
diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py
index ff6ed8f1c..6cd28c8ba 100644
--- a/tunix/rl/rl_cluster.py
+++ b/tunix/rl/rl_cluster.py
@@ -683,7 +683,7 @@ def _log_metrics(self, metrics_buffer: MetricsBuffer) -> None:
         continue
 
       if agg_value.dtype.kind in {"U", "S"}:
-        logging.info(
+        logging.debug(
             "Skipping logging metric %s (dtype: %s)",
             metric_name,
             agg_value.dtype,
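Note: the final hunk only demotes the skip message from `logging.info` to `logging.debug`, since it fires for every string-valued metric on every logging pass. For reference, NumPy reports `dtype.kind == "U"` for unicode and `"S"` for bytes, which is how `_log_metrics` detects metrics that cannot be written as scalars:

```python
import numpy as np

# dtype.kind separates string-like arrays ("U"/"S") from numeric ones.
for value in (np.asarray(["done"]), np.asarray([b"raw"]), np.asarray([0.5])):
  verdict = "skipped" if value.dtype.kind in {"U", "S"} else "logged"
  print(value.dtype, value.dtype.kind, verdict)
# <U4 U skipped
# |S3 S skipped
# float64 f logged
```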