docs/tutorials/rl.md (26 additions, 19 deletions)
@@ -14,19 +14,22 @@
 limitations under the License.
 -->
-# Try GRPO
+# Reinforcement Learning on Single-Host TPUs
-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.
+This tutorial provides step-by-step instructions for setting up the environment and then training the Llama3.1 8B-IT model on the GSM8K math reasoning dataset using a single-host TPU VM such as `v6e-8` or `v5p-8`.
-GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+We use two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:
-We use Tunix as the library for GRPO/GSPO.
-And we use vLLM as the library for efficient model inference and generation.
+* **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
-In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
+* **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
+
+For efficient model inference and response generation during this process, we rely on the vLLM library.
+
+Let's get started!
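As a quick reference for the two algorithms above, the standard formulations (summarized here; this is not MaxText-specific notation) are: GRPO scores each of the `G` sampled responses relative to its group, while GSPO replaces token-level importance ratios with a length-normalized, sequence-level ratio.

```math
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)},
\qquad
s_i(\theta) = \left(\frac{\pi_{\theta}(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)}\right)^{1/|y_i|}
```

Here `r_i` is the reward assigned to response `y_i` for prompt `x`, `A_i` is the group-relative advantage that GRPO plugs into its PPO-style clipped objective, and `s_i(θ)` is the sequence-level ratio that GSPO clips and optimizes in place of per-token ratios.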
 ## Create virtual environment and Install MaxText dependencies
-If you have already completed the [MaxText installation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), you can skip to the next section for vLLM and tpu-inference installations. Otherwise, please install MaxText using the following commands before proceeding.
+If you have already completed the [MaxText installation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), you can skip to the next section for installing the post-training dependencies. Otherwise, please install `MaxText` using the following commands before proceeding.
@@ -58,11 +61,11 @@ Primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-projec
 ### From GitHub
-You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source)
+You can also locally git clone [tunix](https://github.com/google/tunix) and install it using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly, install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
-## Setup the following environment variables before running GRPO
+## Setup environment variables
-Setup following environment variables before running GRPO
+Set up the following environment variables before running GRPO/GSPO:

 You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
-First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the --lazy_load_tensors flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
+First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
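For illustration, a conversion run might look like the sketch below. Only the script path and the `--lazy_load_tensors` flag come from the text above; the remaining argument names and values are placeholders, so check the script's help output for the exact interface.

```bash
# Hypothetical sketch: argument names other than --lazy_load_tensors are placeholders.
# Run on a CPU machine; the converted checkpoint is written to the GCS bucket you specify.
python3 src/MaxText/utils/ckpt_conversion/to_maxtext.py \
  --model_name=<hf-model-id> \
  --hf_access_token=<your-hf-token> \
  --base_output_directory=gs://<your-bucket>/<output-path> \
  --lazy_load_tensors
```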
docs/tutorials/rl_on_multi_host.md (31 additions, 17 deletions)
@@ -14,33 +14,28 @@
 limitations under the License.
 -->
-# Try GRPO with Pathways!
+# Reinforcement Learning on Multi-Host TPUs
-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 70B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.
+This tutorial provides step-by-step instructions for setting up the environment and then training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using [Pathways for orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) on multi-host TPU VMs such as `v5p-128`.
-GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+We use two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:
-GSPO support
-Some workloads prefer Group Sequence Policy Optimization (GSPO), which uses the same infrastructure but a different loss.
-To switch from GRPO to GSPO, add the following override when invoking `train_rl.py` (or when building the `pyconfig` argv list):
-```
-loss_algo=gspo-token
-```
-No other changes are required—the rest of this tutorial applies equally to GSPO runs.
+* **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
+
+* **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
-We use Tunix as the library for GRPO.
-And we use vLLM as the library for efficient model inference and generation.
+For efficient model inference and response generation during this process, we rely on the vLLM library.
-Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers are running on separate mesh. Try out the following recipe `v5p-64`. You can submit jobs to a Pathways enabled GKE cluster.
+Let's get started!
 ## Create virtual environment and Install MaxText dependencies
 Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but we recommend creating the virtual environment outside the `maxtext` directory.
-## Setup the following environment variables before running GRPO
+## Setup environment variables
-Setup following environment variables
+Set up the following environment variables:
 ```bash
 # -- Model configuration --
@@ -118,9 +113,11 @@ bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training PO
-Please create a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk)
+Please create a Pathways-ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and then submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk).
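For orientation, a submission sketched with XPK's generic `workload create` interface might look like the block below. Every value is a placeholder (cluster, project, zone, image, workload name), the `train_rl.py` path and its overrides are left unspecified, and Pathways-enabled clusters may require a Pathways-specific XPK subcommand, so confirm the exact invocation against the XPK documentation.

```bash
# Hypothetical sketch only: all values below are placeholders for your own setup.
# The TPU type mirrors the v5p-128 topology used in this tutorial.
xpk workload create \
  --cluster=<your-pathways-gke-cluster> \
  --project=<your-gcp-project> \
  --zone=<your-zone> \
  --docker-image=<your-post-training-image> \
  --tpu-type=v5p-128 \
  --num-slices=1 \
  --workload=rl-llama3-70b \
  --command="python3 <path-to>/train_rl.py <config-and-overrides>"
```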