Commit e1158b3

Add GSPO and small fixes
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent 238a410 commit e1158b3

6 files changed: +209 −157 lines

README.md

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@

 MaxText is a high performance, highly scalable, open-source LLM library and reference implementation written in pure Python/[JAX](https://docs.jax.dev/en/latest/jax-101.html) and targeting Google Cloud TPUs and GPUs for training.

-MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning).
+MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Group Sequence Policy Optimization (GSPO), the latter two being types of Reinforcement Learning.

 MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from single host to very large clusters while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.

@@ -70,7 +70,7 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d
 Check out these getting started guides:

 * [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh) (Supervised Fine Tuning)
-* [GRPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative Policy Optimization)
+* [GRPO / GSPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative & Group Sequence Policy Optimization – pass `loss_algo=gspo-token` to run GSPO)

 ### Model library

docs/tutorials/grpo.md

Lines changed: 20 additions & 1 deletion

@@ -20,7 +20,7 @@ This tutorial demonstrates step-by-step instructions for setting up the environm

 GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.

-We use Tunix as the library for GRPO.
+We use Tunix as the library for GRPO/GSPO.
 And we use vLLM as the library for efficient model inference and generation.

 In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
@@ -66,3 +66,22 @@ The overview of what this run will do is as follows:
 2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
 3. Train the policy model using GRPO.
 4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
+
+## GSPO (Group Sequence Policy Optimization)
+MaxText can also run the GSPO variant by setting `loss_algo=gspo-token` when invoking `train_rl.py` (or when constructing the pyconfig argv list).
+
+## Run GSPO
+
+Finally, run the command:
+
+```
+python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+model_name=llama3.1-8b \
+tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+load_parameters_path=gs://path/to/checkpoint/0/items \
+run_name=$WORKLOAD \
+base_output_directory=$OUTPUT_PATH \
+hf_access_token=$HF_TOKEN \
+loss_algo=gspo-token
+```

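For intuition about what the `loss_algo=gspo-token` override changes: GRPO clips a separate importance ratio for every generated token, while GSPO uses one length-normalized ratio per sampled sequence. The sketch below is illustrative only, written in plain NumPy with made-up array names and shapes; it is not the Tunix or MaxText implementation.

```python
# Illustrative contrast between GRPO's per-token and GSPO's per-sequence
# importance ratios. Not the Tunix/MaxText code; shapes and names are assumptions.
import numpy as np

def grpo_token_ratios(logp_new, logp_old):
  """GRPO-style: one ratio pi_new / pi_old per token, shape [group, tokens]."""
  return np.exp(logp_new - logp_old)

def gspo_sequence_ratios(logp_new, logp_old, mask):
  """GSPO-style: a single length-normalized ratio per sequence,
  (pi_new(y|x) / pi_old(y|x)) ** (1 / |y|), broadcast over token positions."""
  lengths = mask.sum(axis=-1, keepdims=True)                      # |y| per response
  seq_log_ratio = ((logp_new - logp_old) * mask).sum(axis=-1, keepdims=True) / lengths
  return np.exp(seq_log_ratio) * np.ones_like(mask)

# Toy group of 2 responses with 4 valid tokens each.
rng = np.random.default_rng(0)
logp_old = rng.normal(size=(2, 4))
logp_new = logp_old + 0.1 * rng.normal(size=(2, 4))
mask = np.ones((2, 4))
print(grpo_token_ratios(logp_new, logp_old))            # varies token by token
print(gspo_sequence_ratios(logp_new, logp_old, mask))   # constant within each response
```

In both cases the clipped ratio is multiplied by the group-relative advantage; roughly speaking, the `gspo-token` variant applies the sequence-level ratio at every token position.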
docs/tutorials/grpo_with_pathways.md

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,14 @@ This tutorial demonstrates step-by-step instructions for setting up the environm

 GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.

+## GSPO support
+Some workloads prefer Group Sequence Policy Optimization (GSPO), which uses the same infrastructure but a different loss.
+To switch from GRPO to GSPO, add the following override when invoking `train_rl.py` (or when building the `pyconfig` argv list):
+```
+loss_algo=gspo-token
+```
+No other changes are required; the rest of this tutorial applies equally to GSPO runs.

 We use Tunix as the library for GRPO.
 And we use vLLM as the library for efficient model inference and generation.

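The tutorial above describes "calculating a relative advantage based on the group's performance". As a rough illustration (made-up rewards, not MaxText's reward code), each response's advantage is its reward standardized against the other responses sampled for the same prompt:

```python
# Rough illustration of group-relative advantages; not the MaxText/Tunix implementation.
import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
  """A_i = (r_i - mean(r)) / (std(r) + eps), computed within one prompt's group."""
  rewards = np.asarray(rewards, dtype=np.float32)
  return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example: 4 responses sampled for one prompt, scored 1.0 when the final answer is correct.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))  # correct answers get positive advantage
```

Because the baseline is simply the group mean, no separate value model is needed, which is the memory saving mentioned above.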
src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb

Lines changed: 143 additions & 74 deletions
@@ -4,19 +4,18 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# GRPO Llama3.1-8B Demo: Direct Function Call\n",
+"# GRPO/GSPO Llama3.1-8B Demo\n",
 "\n",
-"This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n",
+"This notebook demonstrates GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization) training using the unified `rl_train` function; the only difference between the two is the loss function, which is passed as a parameter.\n",
 "\n",
-"## What is GRPO?\n",
+"## What is GRPO/GSPO?\n",
 "\n",
-"GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
+"GRPO and GSPO are RL algorithms that enhance the reasoning abilities of LLMs by:\n",
 "1. Generating multiple responses for each prompt\n",
 "2. Evaluating responses using reward models \n",
 "3. Calculating relative advantages to update the policy\n",
 "\n",
-"\n",
-"This notebook imports and calls the `rl_train` function \n",
+"The difference is in the loss function: it either optimizes each token (GRPO) or the whole sequence (GSPO).\n",
 "\n",
 "## Hardware Requirements\n",
 "\n",
@@ -28,9 +27,24 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Setup\n",
+"### Get Your Hugging Face Token\n",
+"\n",
+"To access model checkpoints from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
+"\n",
+"**Follow these steps to get your token:**\n",
+"\n",
+"1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
+" * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
+"\n",
+"2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
 "\n",
-"Install dependencies and set up the environment:"
+"3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
+"\n",
+"4. **Copy the generated token**. You will need to paste it in the next step.\n",
+"\n",
+"**Store your token:**\n",
+"\n",
+"Paste your token into the code cell below."
 ]
 },
 {
@@ -39,30 +53,17 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Clone MaxText repository\n",
-"!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
-"%cd maxtext"
+"HF_TOKEN = \"\" # Paste your Hugging Face token here\n"
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
+"cell_type": "markdown",
 "metadata": {},
-"outputs": [],
 "source": [
-"# Install dependencies\n",
-"!chmod +x setup.sh\n",
-"!./setup.sh\n",
-"\n",
-"# Install GRPO-specific dependencies\n",
-"!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n",
-"\n",
-"# Install additional requirements\n",
-"%pip install --force-reinstall numpy==2.1.2\n",
-"%pip install nest_asyncio\n",
+"## Setup\n",
 "\n",
-"import nest_asyncio\n",
-"nest_asyncio.apply() # Fix for Colab event loop"
+"Install dependencies and set up the environment:\n",
+"https://maxtext.readthedocs.io/latest/tutorials/grpo.html#from-github"
 ]
 },
 {
@@ -71,9 +72,23 @@
 "source": [
 "## Configuration\n",
 "\n",
-"Set up the training parameters:"
+"Set up the training parameters. We do not use Pathways and we run on a single host. Defaults are hardcoded for Llama3.1-8B:"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"%cd ~/maxtext/src/ # make sure we are in the right directory"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": []
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -82,20 +97,48 @@
 "source": [
 "# Configuration for GRPO training\n",
 "import os\n",
+"from huggingface_hub import login # used below to authenticate with Hugging Face\n",
+"import MaxText\n",
+"\n",
+"# Set up paths (adjust if needed)\n",
+"MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
+"RUN_NAME=\"grpo_test\"\n",
+"# Hardcoded defaults for Llama3.1-8B\n",
+"MODEL_NAME = \"llama3.1-8b\"\n",
+"HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
+"CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
+"LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n",
+"\n",
+"# Required: Set these before running\n",
+"MODEL_CHECKPOINT_PATH = \"\" # Update this!\n",
+"if not MODEL_CHECKPOINT_PATH:\n",
+" raise RuntimeError(\"MODEL_CHECKPOINT_PATH is not set\")\n",
+"\n",
+"OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n",
+"os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
 "\n",
-"# Set up paths\n",
-"MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
-"print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n",
+"if HF_TOKEN:\n",
+" login(token=HF_TOKEN)\n",
+" print(\"Authenticated with Hugging Face\")\n",
+"else:\n",
+" print(\"Authentication failed: Hugging Face token not set\")\n",
 "\n",
-"# Training configuration\n",
-"MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n",
-"OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n",
+"# Optional: Override training parameters\n",
 "STEPS = 10 # Reduced for demo purposes\n",
-"HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n",
+"PER_DEVICE_BATCH_SIZE = 1\n",
+"LEARNING_RATE = 3e-6\n",
+"NUM_GENERATIONS = 2\n",
+"GRPO_BETA = 0.08\n",
+"GRPO_EPSILON = 0.2\n",
+"CHIPS_PER_VM = 1\n",
 "\n",
-"print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
-"print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n",
-"print(f\"Training steps: {STEPS}\")"
+"print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n",
+"print(f\"🤖 Model: {MODEL_NAME}\")\n",
+"print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
+"print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n",
+"print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n",
+"print(f\"📊 Steps: {STEPS}\")\n",
+"print(f\"Loss Algorithm: {LOSS_ALGO}\")"
 ]
 },
 {
@@ -104,24 +147,25 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Import GRPO training function directly\n",
-"import sys\n",
+"# Import required modules\n",
 "import os\n",
+"import sys\n",
 "from pathlib import Path\n",
 "\n",
 "# Add MaxText to Python path\n",
-"maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n",
+"maxtext_path = Path(MAXTEXT_REPO_ROOT)\n",
 "sys.path.insert(0, str(maxtext_path))\n",
 "\n",
-"# Import required modules\n",
-"from MaxText import pyconfig\n",
-"from MaxText.train_rl import rl_train\n",
+"from MaxText import pyconfig, max_utils\n",
+"from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n",
+"import jax\n",
 "\n",
-"print(\"✅ Successfully imported GRPO training function\")\n",
-"print(f\"📁 MaxText path: {maxtext_path}\")\n",
-"print(\"\\n\" + \"=\"*80)\n",
-"print(\"Starting GRPO Training...\")\n",
-"print(\"=\"*80)"
+"# Environment settings for JAX and vLLM startup\n",
+"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
+"os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
+"\n",
+"print(\"✅ Successfully imported modules\")\n",
+"print(f\"📁 MaxText path: {maxtext_path}\")"
 ]
 },
 {
@@ -131,28 +175,40 @@
 "outputs": [],
 "source": [
 "# Build configuration for GRPO training\n",
+"config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n",
+"\n",
+"# Verify chat template exists\n",
+"if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
+" raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
+"\n",
+"# Build argv list for pyconfig.initialize()\n",
 "config_argv = [\n",
-" \"\", # Placeholder for argv[0]\n",
-" \"src/MaxText/configs/grpo.yml\", # Base config\n",
-" f\"model_name=llama3.1-8b\",\n",
-" f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n",
+" \"\", # argv[0] placeholder\n",
+" config_file,\n",
+" f\"model_name={MODEL_NAME}\",\n",
+" f\"tokenizer_path={HF_REPO_ID}\",\n",
+" f\"run_name={RUN_NAME}\",\n",
+" f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
 " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
 " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
 " f\"hf_access_token={HF_TOKEN}\",\n",
 " f\"steps={STEPS}\",\n",
-" \"per_device_batch_size=1\",\n",
-" \"learning_rate=3e-6\",\n",
-" \"num_generations=2\",\n",
-" \"grpo_beta=0.08\",\n",
-" \"grpo_epsilon=0.2\",\n",
-" \"chips_per_vm=4\"\n",
+" f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n",
+" f\"learning_rate={LEARNING_RATE}\",\n",
+" f\"num_generations={NUM_GENERATIONS}\",\n",
+" f\"grpo_beta={GRPO_BETA}\",\n",
+" f\"grpo_epsilon={GRPO_EPSILON}\",\n",
+" f\"chips_per_vm={CHIPS_PER_VM}\",\n",
+" f\"loss_algo={LOSS_ALGO}\",\n",
+" \"use_pathways=False\"\n",
 "]\n",
 "\n",
-"# Create configuration object\n",
-"config = pyconfig.Config()\n",
-"config.parse_flags(config_argv)\n",
+"# Initialize configuration\n",
+"print(f\"🔧 Initializing configuration from: {config_file}\")\n",
+"config = pyconfig.initialize(config_argv)\n",
+"max_utils.print_system_information()\n",
 "\n",
-"print(\"✅ Configuration created successfully\")\n",
+"print(\"\\n✅ Configuration initialized successfully\")\n",
 "print(f\"📊 Training steps: {config.steps}\")\n",
 "print(f\"📁 Output directory: {config.base_output_directory}\")\n",
 "print(f\"🤖 Model: {config.model_name}\")"
@@ -164,33 +220,46 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Execute GRPO training directly\n",
+"# Execute GRPO/GSPO training\n",
+"print(\"\\n\" + \"=\"*80)\n",
+"print(\"🚀 Starting Training...\")\n",
+"print(\"=\"*80)\n",
 "try:\n",
-" # Call the rl_train function\n",
-" grpo_trainer, rl_cluster = rl_train(config)\n",
+" # Call the rl_train function (it handles everything internally)\n",
+" rl_train(config)\n",
 " \n",
 " print(\"\\n\" + \"=\"*80)\n",
-" print(\"GRPO Training Completed Successfully!\")\n",
+" print(\"✅ Training Completed Successfully!\")\n",
 " print(\"=\"*80)\n",
-" print(f\"📁 Checkpoints and logs saved to: {config.base_output_directory}\")\n",
-" print(f\"🎯 Final model ready for inference!\")\n",
+" print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
+" print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n",
+" print(f\"🎯 Model ready for inference!\")\n",
 " \n",
 "except Exception as e:\n",
 " print(\"\\n\" + \"=\"*80)\n",
-" print(\" GRPO Training Failed!\")\n",
+" print(\"❌ Training Failed!\")\n",
 " print(\"=\"*80)\n",
 " print(f\"Error: {str(e)}\")\n",
-" print(\"\\nPlease check the error message and try again.\")"
+" import traceback\n",
+" traceback.print_exc()\n",
+" print(\"\\n💡 Common issues:\")\n",
+" print(\" - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint\")\n",
+" print(\" - Ensure HF_TOKEN environment variable is set\")\n",
+" print(\" - Verify OUTPUT_DIRECTORY is writable\")\n",
+" print(\" - Check hardware requirements (TPU/GPU availability)\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### 📚 **Learn More**\n",
-"- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n",
-"- Check `src/MaxText/configs/grpo.yml` for configuration options\n",
-"- Read `src/MaxText/examples/README.md` for more examples"
+"## 📚 Learn More\n",
+"\n",
+"- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml model_name=llama3.1-8b ...`\n",
+"- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
+"- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation\n",
+"- **Examples**: See other examples in `src/MaxText/examples/`"
 ]
 }
 ],

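A closing note on the notebook above: switching it from GRPO to GSPO is a one-variable change in its configuration cell, since `config_argv` already forwards `loss_algo` to MaxText. A sketch of that change; nothing else in the notebook needs to be touched:

```python
# Sketch: re-run the notebook's configuration cell with the GSPO loss selected.
LOSS_ALGO = "gspo-token"  # the notebook's default is "grpo"

# config_argv already contains f"loss_algo={LOSS_ALGO}", so rebuilding it and
# re-running pyconfig.initialize(config_argv) followed by rl_train(config)
# launches a GSPO run instead of a GRPO run.
print(f"loss_algo={LOSS_ALGO}")  # -> loss_algo=gspo-token, the same override used by the CLI examples
```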