README.md (2 additions & 2 deletions)

@@ -22,7 +22,7 @@
MaxText is a high performance, highly scalable, open-source LLM library and reference implementation written in pure Python/[JAX](https://docs.jax.dev/en/latest/jax-101.html) and targeting Google Cloud TPUs and GPUs for training.
- MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning).
+ MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Reinforcement Learning methods such as Group Relative Policy Optimization (GRPO) and Group Sequence Policy Optimization (GSPO).
MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from single host to very large clusters while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.
@@ -70,7 +70,7 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d
Check out these getting started guides:
* [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh) (Supervised Fine Tuning)
​
docs/tutorials/grpo.md (20 additions & 1 deletion)

@@ -20,7 +20,7 @@ This tutorial demonstrates step-by-step instructions for setting up the environm
GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
- We use Tunix as the library for GRPO.
+ We use Tunix as the library for GRPO/GSPO.
And we use vLLM as the library for efficient model inference and generation.
In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
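
The GRPO description above mentions computing a relative advantage from each group of sampled responses. As a rough, illustrative sketch only (not Tunix's or MaxText's actual loss code; the function name and shapes are assumptions), the group-relative advantage can be written as:

```python
import jax.numpy as jnp

def group_relative_advantages(rewards: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    """Toy sketch: `rewards` has shape [num_prompts, group_size], one scalar reward
    per sampled response. Each response's advantage is its reward normalized by the
    mean and standard deviation of its own group, so no separate value model is needed."""
    mean = rewards.mean(axis=-1, keepdims=True)
    std = rewards.std(axis=-1, keepdims=True)
    return (rewards - mean) / (std + eps)

# Example: one prompt, four sampled responses scored by a reward model.
print(group_relative_advantages(jnp.array([[0.0, 1.0, 1.0, 0.0]])))
```
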
@@ -112,3 +112,22 @@ The overview of what this run will do is as follows:
2. Evaluate the policy model's performance on the GSM8K math reasoning benchmark.
3. Train the policy model using GRPO.
4. Evaluate the policy model's performance on the GSM8K math reasoning benchmark after the post-training with GRPO.
+
+ GSPO (Group Sequence Policy Optimization)
+ MaxText can also run the GSPO variant by setting `loss_algo=gspo-token` when invoking `train_rl.py` (or when constructing the pyconfig argv list).
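
As a further illustration of the line above (the config file path and the other overrides here are placeholder assumptions, not values taken from this change), the GSPO override can simply be appended to the argv list that is handed to pyconfig:

```python
# Hypothetical sketch of building a pyconfig-style argv list for train_rl.py.
# "path/to/rl_config.yml" and "run_name=..." are illustrative placeholders.
base_argv = [
    "train_rl.py",
    "path/to/rl_config.yml",
    "run_name=llama3_rl_demo",
]

# GRPO run: use the base arguments as-is (assuming GRPO is the default loss).
grpo_argv = list(base_argv)

# GSPO run: the only change is the loss_algo override.
gspo_argv = base_argv + ["loss_algo=gspo-token"]
```

The same `loss_algo=gspo-token` key=value pair can instead be passed directly on the `train_rl.py` command line.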
​
docs/tutorials/grpo_with_pathways.md (8 additions & 0 deletions)

@@ -20,6 +20,14 @@ This tutorial demonstrates step-by-step instructions for setting up the environm
GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+ GSPO support
+ Some workloads prefer Group Sequence Policy Optimization (GSPO), which uses the same infrastructure but a different loss.
+ To switch from GRPO to GSPO, add the following override when invoking `train_rl.py` (or when building the `pyconfig` argv list):
+ ```
+ loss_algo=gspo-token
+ ```
+ No other changes are required; the rest of this tutorial applies equally to GSPO runs.
+
We use Tunix as the library for GRPO.
And we use vLLM as the library for efficient model inference and generation.
​
docs/tutorials/how_to_run_colabs.md (26 additions & 10 deletions)

@@ -59,7 +59,7 @@ Upload notebooks or mount your GitHub repo
2. Try:
  - `sft_qwen3_demo.ipynb`
  - `sft_llama3_demo.ipynb`
- - `grpo_llama3_demo.ipynb`
+ - `rl_llama3_demo.ipynb` (GRPO/GSPO training)
> ⚡ **Tip:** If Colab disconnects, re-enable TPU and re-run setup cells. Save checkpoints to GCS or Drive.
@@ -125,22 +125,25 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla
  - **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
  - **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- ### GRPO Training
+ ### Reinforcement Learning (GRPO/GSPO) Training
- - **`grpo_llama3_1_8b_demo.ipynb`** → GRPO training on math dataset (Colab/notebook)
+ - **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on math dataset (Colab/notebook)
- #### GRPO Colab Usage
+ #### GRPO/GSPO Colab Usage
- For interactive GRPO training in Google Colab or Jupyter:
+ For interactive GRPO or GSPO training in Google Colab or Jupyter:
  2. **Enable TPU runtime** (Runtime → Change runtime type → TPU)
- 3. **Run cells** to train Llama3.1-8B with GRPO on GSM8K dataset
+ 3. **Set `LOSS_ALGO`** to `"grpo"` for GRPO or `"gspo-token"` for GSPO
+ 4. **Run cells** to train Llama3.1-8B with GRPO or GSPO on GSM8K dataset
- #### GRPO Python Script Usage - local runs
+ > **Note:** GRPO (Group Relative Policy Optimization) optimizes each token, while GSPO (Group Sequence Policy Optimization) optimizes the whole sequence. The difference is controlled by the `loss_algo` parameter.
+ > **Note:** To use GSPO instead of GRPO, add `--loss_algo=gspo-token` to the command. GRPO optimizes each token, while GSPO optimizes the whole sequence.
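
To unpack the two notes above, here is a purely conceptual sketch of the difference (this is not the Tunix or MaxText loss implementation, and the exact `gspo-token` formulation may differ in detail): GRPO weights each token by its own importance ratio, whereas GSPO applies a single sequence-level ratio to the whole response.

```python
import jax.numpy as jnp

def importance_ratios(logp_new, logp_old, mask, loss_algo="grpo"):
    """Conceptual sketch. logp_new/logp_old: per-token log-probs [batch, seq_len]
    under the current and behavior policies; mask: 1.0 for response tokens, 0.0 for padding."""
    per_token_log_ratio = (logp_new - logp_old) * mask
    if loss_algo == "grpo":
        # GRPO: one importance ratio per token.
        return jnp.exp(per_token_log_ratio)
    # GSPO: a single length-normalized sequence-level ratio, broadcast over all tokens.
    seq_ratio = jnp.exp(per_token_log_ratio.sum(-1) / jnp.maximum(mask.sum(-1), 1.0))
    return seq_ratio[:, None] * jnp.ones_like(per_token_log_ratio)
```
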
+
+ #### GRPO/GSPO Python Script Usage - cluster runs
For running on clusters, please refer to `maxtext/docs/tutorials/grpo_with_pathways.md`