Commit 497d9e3

Update documentation for GSPO
1 parent 60028c4 commit 497d9e3

File tree

8 files changed, +98 -86 lines changed

README.md

Lines changed: 3 additions & 4 deletions

@@ -78,10 +78,9 @@ Check out these getting started guides:
* Supervised Fine Tuning (SFT)
  * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html)
  * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft_on_multi_host.html)
-* Group Relative & Group Sequence Policy Optimization (GRPO & GSPO)
-  * [GRPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html)
-  * [GRPO on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo_with_pathways.html)
-  * [GSPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html#run-gspo) (pass `loss_algo=gspo-token` to run GSPO)
+* Reinforcement Learning (RL)
+  * [RL on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/rl.html)
+  * [RL on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/rl_on_multi_host.html)

### Model library

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ ARG MODE

ENV MODE=$MODE

-RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
+RUN echo "Installing Post-Training dependencies (vLLM, tpu-inference, tunix) with MODE=${MODE}"

# Uninstall existing jax to avoid conflicts
RUN pip uninstall -y jax jaxlib libtpu

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ FROM ${BASEIMAGE}
ARG MODE
ENV MODE=$MODE

-RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
+RUN echo "Installing Post-Training dependencies (tunix, vLLM, tpu-inference) with MODE=${MODE}"
RUN pip uninstall -y jax jaxlib libtpu

RUN pip install aiohttp==3.12.15

docs/tutorials/post_training_index.md

Lines changed: 10 additions & 10 deletions

@@ -18,14 +18,14 @@ MaxText was co-designed with key Google led innovations to provide a unified pos

## Supported techniques & models

-- **SFT (Supervised Fine-Tuning)** [(link)](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html)
-  - Supports all MaxText models
-- **Multimodal SFT** [(link)](https://maxtext.readthedocs.io/en/latest/guides/multimodal.html)
-- **GRPO (Group Relative Policy Optimization)** [(link)](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html)
-  - Llama 3.1 8B
-  - Llama 3.1 70B
-- **GSPO-token**
-  - Coming soon
+- **SFT (Supervised Fine-Tuning)**
+  * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html)
+  * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft_on_multi_host.html)
+- **Multimodal SFT**
+  * [Multimodal Support](https://maxtext.readthedocs.io/en/latest/guides/multimodal.html)
+- **Reinforcement Learning (RL)**
+  * [RL on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/rl.html)
+  * [RL on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/rl_on_multi_host.html)

## Step by step RL

@@ -58,8 +58,8 @@ Start your Post-Training journey through quick experimentation with our [Google

full_finetuning.md
how_to_run_colabs.md
-grpo.md
+rl.md
sft.md
sft_on_multi_host.md
-grpo_with_pathways.md
+rl_on_multi_host.md
```

docs/tutorials/rl.md

Lines changed: 26 additions & 19 deletions

@@ -14,19 +14,22 @@
limitations under the License.
-->

-# Try GRPO
+# Reinforcement Learning on Single-Host TPUs

-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.
+This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B-IT model on the GSM8K math reasoning dataset using a single host TPU-VM such as `v6e-8/v5p-8`.

-GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:

-We use Tunix as the library for GRPO/GSPO.
-And we use vLLM as the library for efficient model inference and generation.
+* **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.

-In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
+* **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
+
+For efficient model inference and response generation during this process, we rely on the vLLM library.
+
+Let's get started!

## Create virtual environment and Install MaxText dependencies
-If you have already completed the [MaxText installation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), you can skip to the next section for vLLM and tpu-inference installations. Otherwise, please install MaxText using the following commands before proceeding.
+If you have already completed the [MaxText installation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), you can skip to the next section for post-training dependencies installations. Otherwise, please install `MaxText` using the following commands before proceeding.
```bash
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
install_maxtext_github_deps
4447
```
4548

46-
## vLLM and tpu-inference installations
49+
## Install Post-Training dependencies
4750

4851
### From PyPI releases
4952

@@ -58,11 +61,11 @@ Primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-projec
5861

5962
### From Github
6063

61-
You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source)
64+
You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
6265

63-
## Setup the following environment variables before running GRPO
66+
## Setup environment variables
6467

65-
Setup following environment variables before running GRPO
68+
Setup following environment variables before running GRPO/GSPO:
6669

6770
```bash
6871
# -- Model configuration --
@@ -82,7 +85,7 @@ export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
8285

8386
You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
8487

85-
First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the --lazy_load_tensors flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
88+
First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
8689

8790
```bash
8891
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -108,7 +111,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
108111

109112
## Run GRPO
110113

111-
Finally, run the command
114+
Run the following command for GRPO:
112115

113116
```
114117
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
@@ -120,19 +123,16 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
120123
hf_access_token=${HF_TOKEN}
121124
```
122125

123-
The overview of the what this run will do is as follows:
126+
The overview of what this run will do is as follows:
124127

125128
1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
126129
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
127130
3. Train the policy model using GRPO.
128-
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
129-
130-
GSPO (Group Sequence Policy Optimization)
131-
MaxText can also run the GSPO variant by setting `loss_algo=gspo-token` when invoking `train_rl.py` (or when constructing the pyconfig argv list).
131+
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
132132

133133
## Run GSPO
134134

135-
Finally, run the command
135+
Run the following command for GSPO:
136136

137137
```
138138
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
@@ -145,3 +145,10 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
145145
loss_algo=gspo-token
146146
```
147147

148+
The overview of what this run will do is as follows:
149+
150+
1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
151+
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
152+
3. Train the policy model using GSPO.
153+
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GSPO.
154+
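
The GRPO description added in this file explains that the policy update is driven by a relative advantage computed over a group of sampled responses. As a minimal illustration of that idea only (plain NumPy, not the Tunix or MaxText implementation; the function name, group size, and reward values are invented for the example), a group-relative advantage can be sketched as:

```python
import numpy as np

def group_relative_advantage(group_rewards: np.ndarray) -> np.ndarray:
    """Toy sketch: normalize each sampled response's reward against its group.

    The group mean (and standard deviation) serves as the baseline, which is
    how GRPO-style methods avoid training a separate value function model.
    """
    mean = group_rewards.mean()
    std = group_rewards.std()
    return (group_rewards - mean) / (std + 1e-6)

# Four sampled answers to one GSM8K prompt, rewarded 1.0 if correct, 0.0 otherwise.
rewards = np.array([1.0, 0.0, 0.0, 1.0])
print(group_relative_advantage(rewards))  # correct answers receive a positive advantage
```

Responses that beat their group's average are reinforced and the rest are discouraged, which is the behaviour the tutorial text attributes to GRPO.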

docs/tutorials/rl_on_multi_host.md

Lines changed: 31 additions & 17 deletions

@@ -14,33 +14,28 @@
limitations under the License.
-->

-# Try GRPO with Pathways!
+# Reinforcement Learning on Multi-Host TPUs

-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 70B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.
+This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using [Pathways for orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) on multi-host TPU-VMs such as `v5p-128`.

-GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:

-GSPO support
-Some workloads prefer Group Sequence Policy Optimization (GSPO), which uses the same infrastructure but a different loss.
-To switch from GRPO to GSPO, add the following override when invoking `train_rl.py` (or when building the `pyconfig` argv list):
-```
-loss_algo=gspo-token
-```
-No other changes are required—the rest of this tutorial applies equally to GSPO runs.
+* **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+
+* **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.

-We use Tunix as the library for GRPO.
-And we use vLLM as the library for efficient model inference and generation.
+For efficient model inference and response generation during this process, we rely on the vLLM library.

-Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers are running on separate mesh. Try out the following recipe `v5p-64`. You can submit jobs to a Pathways enabled GKE cluster.
+Let's get started!

## Create virtual environment and Install MaxText dependencies
Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but
recommend creating the virtual environment outside the `maxtext` directory.


-## Setup the following environment variables before running GRPO
+## Setup environment variables

-Setup following environment variables before running GRPO
+Setup following environment variables:

```bash
# -- Model configuration --
@@ -118,9 +113,11 @@ bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training PO
bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
```

-## Submit your jobs
+## Submit your RL workload via Pathways

-Please create a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk)
+Please create a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk).
+
+### Submit GRPO workload
```
xpk workload create-pathways --workload $WORKLOAD \
--docker-image <path/to/gcr.io> --cluster $TPU_CLUSTER \
@@ -135,3 +132,20 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=$HF_TOKEN"
```
+
+### Submit GSPO workload
+```
+xpk workload create-pathways --workload $WORKLOAD \
+--docker-image <path/to/gcr.io> --cluster $TPU_CLUSTER \
+--tpu-type=$TPU_TYPE --num-slices=1 --zone=$ZONE \
+--project=$PROJECT_ID --priority=high \
+--command "TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
+python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+model_name=${MODEL} \
+tokenizer_path=${TOKENIZER} \
+load_parameters_path=${MAXTEXT_CKPT_PATH} \
+run_name=${RUN_NAME} \
+base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+hf_access_token=$HF_TOKEN \
+loss_algo=gspo-token"
+```
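
Both tutorials in this commit switch between the two algorithms with the single override `loss_algo=gspo-token`, and describe GSPO's difference as a sequence-level importance ratio with sequence-level clipping. The following is a hedged, self-contained sketch of that distinction only (plain NumPy, not the Tunix loss code; the log-probability values, the clipping range `eps`, and the length normalization shown are assumptions made for illustration):

```python
import numpy as np

# Per-token log-probabilities of one sampled response under the current
# policy and under the policy that generated the sample (illustrative values).
new_logp = np.array([-0.9, -1.2, -0.4, -1.0])
old_logp = np.array([-1.0, -1.1, -0.5, -1.3])
eps = 0.2  # assumed clipping range

# GRPO-style: one importance ratio per token, clipped token by token.
token_ratio = np.exp(new_logp - old_logp)
clipped_token_ratio = np.clip(token_ratio, 1.0 - eps, 1.0 + eps)

# GSPO-style: a single ratio for the whole sequence, derived from the
# sequence likelihood (length-normalized here) and clipped once per sequence.
seq_ratio = np.exp((new_logp.sum() - old_logp.sum()) / len(new_logp))
clipped_seq_ratio = np.clip(seq_ratio, 1.0 - eps, 1.0 + eps)

print(clipped_token_ratio)  # four per-token ratios
print(clipped_seq_ratio)    # one ratio shared by the whole sequence
```

In either case the clipped ratio is combined with the group-relative advantage in a PPO-style objective; the override only changes which ratio is used, so the rest of the workflow in these tutorials stays the same.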

src/MaxText/examples/rl_llama3_demo.ipynb

Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@
"## Setup\n",
"\n",
"Install dependencies and set up the environment:\n",
-"https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html#from-github"
+"https://maxtext.readthedocs.io/en/latest/tutorials/rl.html#from-github"
]
},
{
@@ -256,7 +256,7 @@
"source": [
"## 📚 Learn More\n",
"\n",
-"- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html#run-grpo\n",
+"- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html#run-grpo\n",
"- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
"- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation"
]
