
Commit c45c6c5

Merge pull request #2693 from AI-Hypercomputer:gspo_and_fixes
PiperOrigin-RevId: 837267129
2 parents ed517cf + 58f10c9 commit c45c6c5

7 files changed: +382 additions, −346 deletions


README.md

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
MaxText is a high performance, highly scalable, open-source LLM library and reference implementation written in pure Python/[JAX](https://docs.jax.dev/en/latest/jax-101.html) and targeting Google Cloud TPUs and GPUs for training.

- MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning).
+ MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Group Sequence Policy Optimization (GSPO), the latter two being types of Reinforcement Learning.

MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from single host to very large clusters while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.

@@ -70,7 +70,7 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d
Check out these getting started guides:

* [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh) (Supervised Fine Tuning)
- * [GRPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative Policy Optimization)
+ * [GRPO / GSPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative & Group Sequence Policy Optimization – pass `loss_algo=gspo-token` to run GSPO)

### Model library

docs/tutorials/grpo.md

Lines changed: 20 additions & 1 deletion
@@ -20,7 +20,7 @@ This tutorial demonstrates step-by-step instructions for setting up the environm
GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.

- We use Tunix as the library for GRPO.
+ We use Tunix as the library for GRPO/GSPO.
And we use vLLM as the library for efficient model inference and generation.

In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
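
As an aside, the group-relative advantage described in the GRPO paragraph above can be sketched in a few lines of Python. This is an illustrative sketch only, not code from this commit or from Tunix/MaxText; the function and variable names are hypothetical, and it simply normalizes each reward against the mean and standard deviation of its prompt group.

```python
import numpy as np

def group_relative_advantage(rewards: np.ndarray, group_size: int, eps: float = 1e-6) -> np.ndarray:
    """Normalize each reward against the other responses sampled for the same prompt.

    `rewards` is a flat array of length num_prompts * group_size, with the
    `group_size` completions for each prompt stored contiguously. All names
    here are hypothetical; this is not the Tunix/MaxText implementation.
    """
    grouped = rewards.reshape(-1, group_size)        # (num_prompts, group_size)
    mean = grouped.mean(axis=1, keepdims=True)       # per-group mean reward
    std = grouped.std(axis=1, keepdims=True)         # per-group spread
    advantages = (grouped - mean) / (std + eps)      # relative advantage within the group
    return advantages.reshape(-1)

# Example: 2 prompts, 4 sampled responses each.
rewards = np.array([1.0, 0.0, 0.0, 1.0, 0.2, 0.4, 0.6, 0.8])
print(group_relative_advantage(rewards, group_size=4))
```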
@@ -112,3 +112,22 @@ The overview of what this run will do is as follows:
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
3. Train the policy model using GRPO.
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.

+ ## GSPO (Group Sequence Policy Optimization)
+
+ MaxText can also run the GSPO variant by setting `loss_algo=gspo-token` when invoking `train_rl.py` (or when constructing the pyconfig argv list).
+
+ ## Run GSPO
+
+ Finally, run the following command:
+
+ ```bash
+ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+ model_name=llama3.1-8b \
+ tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+ load_parameters_path=gs://path/to/checkpoint/0/items \
+ run_name=$WORKLOAD \
+ base_output_directory=$OUTPUT_PATH \
+ hf_access_token=$HF_TOKEN \
+ loss_algo=gspo-token
+ ```
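
For the "pyconfig argv list" route mentioned above, a rough sketch is shown below. It is not part of this commit: it assumes `pyconfig.initialize` accepts an argv-style list (program name, config path, then `key=value` overrides) as in other MaxText entry points, and the import path may differ depending on how the repository is installed.

```python
# Hypothetical sketch of the "pyconfig argv list" route; not code from this commit.
# Assumes pyconfig.initialize accepts an argv-style list (program name, config
# path, then key=value overrides); adjust the import (MaxText vs. src.MaxText)
# to match how the repository is installed.
from MaxText import pyconfig

argv = [
    "train_rl",                                        # placeholder program name
    "src/MaxText/configs/rl.yml",                      # RL base config
    "model_name=llama3.1-8b",
    "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
    "load_parameters_path=gs://path/to/checkpoint/0/items",
    "loss_algo=gspo-token",                            # select the GSPO loss instead of GRPO
]
config = pyconfig.initialize(argv)                     # downstream RL code reads loss_algo from config
```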

docs/tutorials/grpo_with_pathways.md

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,14 @@ This tutorial demonstrates step-by-step instructions for setting up the environm
GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.

+ **GSPO support**
+ Some workloads prefer Group Sequence Policy Optimization (GSPO), which uses the same infrastructure but a different loss.
+ To switch from GRPO to GSPO, add the following override when invoking `train_rl.py` (or when building the `pyconfig` argv list):
+ ```
+ loss_algo=gspo-token
+ ```
+ No other changes are required; the rest of this tutorial applies equally to GSPO runs.

We use Tunix as the library for GRPO.
And we use vLLM as the library for efficient model inference and generation.

docs/tutorials/how_to_run_colabs.md

Lines changed: 26 additions & 10 deletions
@@ -59,7 +59,7 @@ Upload notebooks or mount your GitHub repo
2. Try:
- `sft_qwen3_demo.ipynb`
- `sft_llama3_demo.ipynb`
- - `grpo_llama3_demo.ipynb`
+ - `rl_llama3_demo.ipynb` (GRPO/GSPO training)

> **Tip:** If Colab disconnects, re-enable TPU and re-run setup cells. Save checkpoints to GCS or Drive.
@@ -125,22 +125,25 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colab
- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)

- ### GRPO Training
+ ### Reinforcement Learning (GRPO/GSPO) Training

- - **`grpo_llama3_1_8b_demo.ipynb`** → GRPO training on math dataset (Colab/notebook)
+ - **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on math dataset (Colab/notebook)

- #### GRPO Colab Usage
+ #### GRPO/GSPO Colab Usage

- For interactive GRPO training in Google Colab or Jupyter:
+ For interactive GRPO or GSPO training in Google Colab or Jupyter:

- 1. **Open** `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`
+ 1. **Open** `src/MaxText/examples/rl_llama3_demo.ipynb`
2. **Enable TPU runtime** (Runtime → Change runtime type → TPU)
- 3. **Run cells** to train Llama3.1-8B with GRPO on GSM8K dataset
+ 3. **Set `LOSS_ALGO`** to `"grpo"` for GRPO or `"gspo-token"` for GSPO
+ 4. **Run cells** to train Llama3.1-8B with GRPO or GSPO on GSM8K dataset

- #### GRPO Python Script Usage - local runs
+ > **Note:** GRPO (Group Relative Policy Optimization) optimizes each token, while GSPO (Group Sequence Policy Optimization) optimizes the whole sequence. The difference is controlled by the `loss_algo` parameter.
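
The token-versus-sequence distinction in the note above can be illustrated with a small sketch. This is not the Tunix or MaxText loss code; it omits PPO-style clipping and the per-token correction used by `gspo-token`, and the numbers are made up.

```python
import numpy as np

# Illustrative only; not the Tunix/MaxText loss code. Clipping and the
# gspo-token per-token correction are omitted, and the values are made up.
# logp_new / logp_old: per-token log-probabilities of one sampled response
# under the current policy and the policy that generated it.
logp_new = np.array([-1.2, -0.7, -2.1, -0.9])
logp_old = np.array([-1.0, -0.8, -2.0, -1.1])
advantage = 0.5                                      # group-relative advantage for this response

# GRPO-style (token level): one importance ratio per token.
token_ratios = np.exp(logp_new - logp_old)           # shape (T,)
grpo_objective = np.mean(token_ratios * advantage)

# GSPO-style (sequence level): a single length-normalized ratio for the
# whole response, shared by every token position.
seq_ratio = np.exp(np.mean(logp_new - logp_old))     # geometric mean of the token ratios
gspo_objective = seq_ratio * advantage

print(grpo_objective, gspo_objective)
```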

+ #### GRPO/GSPO Python Script Usage - local runs

```bash
- # Llama3.1-8B-Instruct
+ # Llama3.1-8B-Instruct with GRPO (default)
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-8b \
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
@@ -149,6 +152,16 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN

+ # Llama3.1-8B-Instruct with GSPO
+ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+ --model_name=llama3.1-8b \
+ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+ --load_parameters_path=gs://path/to/checkpoint/0/items \
+ --run_name=$WORKLOAD \
+ --base_output_directory=$OUTPUT_PATH \
+ --hf_access_token=$HF_TOKEN \
+ --loss_algo=gspo-token

# Qwen2.5-7B
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=qwen2.5-7b \
@@ -158,7 +171,10 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN
```
- #### GRPO Python Script Usage - cluster runs

+ > **Note:** To use GSPO instead of GRPO, add `--loss_algo=gspo-token` to the command. GRPO optimizes each token, while GSPO optimizes the whole sequence.

+ #### GRPO/GSPO Python Script Usage - cluster runs

For running on clusters, please refer to `maxtext/docs/tutorials/grpo_with_pathways.md`

src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb

Lines changed: 0 additions & 218 deletions
This file was deleted.
