# GRPO: Group Relative Policy Optimization

Group Relative Policy Optimization (GRPO) is a reinforcement learning algorithm proposed by DeepSeek for training large language models. This repository aggregates and refactors **four distinct implementations** of GRPO, each demonstrating a different approach to the core algorithm while sharing common principles.

## Algorithm

The core GRPO algorithm follows these steps:

1. For each training step, randomly sample $N$ questions $q_1, q_2, \cdots, q_N$.
2. For each question $q_i$, sample $M$ answers $a_{i,1}, a_{i,2}, \cdots, a_{i,M}$.
3. Compute the reward $r_{i,j}$ for each answer $a_{i,j}$.
4. Compute group statistics for each question $q_i$:

    $$
    \begin{alignedat}{2}
    &\mu_i &&\leftarrow \text{mean}(r_{i,1}, r_{i,2}, \cdots, r_{i,M}) \\
    &\sigma_i &&\leftarrow \text{std}(r_{i,1}, r_{i,2}, \cdots, r_{i,M})
    \end{alignedat}
    $$

5. For each token $t$ in answer $a_{i,j}$, compute the advantage:

    $$A_{i,j}[t] \leftarrow \frac{r_{i,j} - \mu_i}{\sigma_i}$$

6. Update the policy with the per-token policy-gradient term (in practice wrapped in a PPO-style clipped surrogate):

    $$\nabla_\theta \log \pi_\theta(a_{i,j}[t]) \cdot A_{i,j}[t]$$

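Steps 4 and 5 can be sketched in a few lines of plain Python. The `eps` term guarding against zero group variance is an added stabilizer commonly used in practice, not part of the formulas above:

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """Normalize one question's group of rewards to zero mean / unit std.

    Every token of answer a_{i,j} shares the same advantage (r_{i,j} - mu) / sigma.
    """
    mu = mean(rewards)           # mu_i
    sigma = pstdev(rewards)      # sigma_i (population std)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One question, M = 4 sampled answers with binary rewards:
advs = group_advantages([1.0, 0.0, 1.0, 0.0])
```

Because the normalization is per group, an answer is only rewarded for being better than its siblings for the same question, which is what makes GRPO critic-free.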
## Implementations

We provide four refactored implementations of GRPO, each with a different focus and design:

### 1. [nanoAhaMoment](src/grpo/nanoAhaMoment)

An implementation from [nanoAhaMoment](https://github.com/nanoAhaMoment/nanoAhaMoment) that separates each step of the GRPO loop into distinct components. It uses a rule-based reward function for a Countdown task and integrates with vLLM for efficient generation.

- Modular pipeline with separated components
- vLLM integration for efficient generation
- DeepSpeed training backend
- Format: `<think>...</think>\n<answer>...</answer>`
- Rule-based reward functions for Countdown tasks

### 2. [GRPO:Zero](src/grpo/GRPO-Zero)

An implementation from [GRPO-Zero](https://github.com/policy-gradient/GRPO-Zero) that runs the reference model on a separate server to offload computation. It uses the Countdown-Tasks-3to4 dataset and a combined reward for correctness and format.

- Qwen2.5-3B-Instruct base model
- Countdown-Tasks-3to4 dataset
- Simplified training workflow
- Reward function: combined reward for correctness and format

### 3. [Simple GRPO](src/grpo/Simple_GRPO)

An implementation from [Simple GRPO](https://github.com/lsdefine/simple_GRPO) that uses DeepSpeed for training and a reference-model server. It features a policy-gradient loss with a KL penalty and reward normalization within groups.

- Reference model server architecture
- GSM8K dataset
- KL divergence penalty term
- Per-token advantage calculation
- Distributed training support
- Loss calculation: `loss = -(policy_ratio * advantage - beta * kl_divergence)`

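The loss line above can be sketched per token as follows. The variable names and the use of the `k3` KL estimator are illustrative assumptions, not taken verbatim from Simple GRPO:

```python
import math

def grpo_token_loss(logp, logp_old, logp_ref, advantage, beta=0.04):
    """Per-token GRPO loss: policy-gradient term minus a weighted KL penalty.

    logp / logp_old / logp_ref are log-probabilities of the sampled token under
    the current, behavior (sampling-time), and reference policies.
    """
    policy_ratio = math.exp(logp - logp_old)
    # k3 KL estimator: exp(ref - cur) - (ref - cur) - 1, which is always >= 0
    kl_divergence = math.exp(logp_ref - logp) - (logp_ref - logp) - 1.0
    return -(policy_ratio * advantage - beta * kl_divergence)
```

When the current policy still equals the sampling policy, the ratio is 1 and the loss reduces to `-advantage` plus the KL penalty, recovering plain policy gradient.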
### 4. [GRPO from Scratch](src/grpo/GRPO_from_Scratch)

An implementation from ["The LM Book" by Andriy Burkov](https://github.com/aburkov/theLMbook/blob/main/GRPO.py) that demonstrates the core GRPO algorithm step by step. It keeps a copy of the reference model and performs multiple updates per batch.

- Periodic reference model updates
- Multiple updates per batch (μ-PPO)
- Comprehensive reward decomposition
- Memory optimization techniques
- Reward function: combined reward for correctness and format

## Common Components

All implementations share the following steps:

- **Group Sampling**: For each prompt, multiple completions are generated to form a group.
- **Reward Calculation**: Each completion receives a scalar reward, typically combining correctness and format adherence.
- **Advantage Normalization**: Within each group, rewards are normalized to zero mean and unit variance to form advantages.
- **Policy Update**: The policy is updated with a policy-gradient method (with or without clipping), often including a KL penalty that discourages deviation from a reference policy.

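These four shared stages compose into one training step; a minimal skeleton, with the sampling, reward, and update functions passed in as stubs (all names here are illustrative, not from any of the four implementations):

```python
def grpo_step(questions, sample_answer, reward_fn, update_policy, M=4):
    """One GRPO training step over the four shared stages."""
    batch = []
    for q in questions:
        group = [sample_answer(q) for _ in range(M)]          # group sampling
        rewards = [reward_fn(q, a) for a in group]            # reward calculation
        mu = sum(rewards) / M
        sigma = (sum((r - mu) ** 2 for r in rewards) / M) ** 0.5
        advs = [(r - mu) / (sigma + 1e-4) for r in rewards]   # advantage normalization
        batch.extend(zip(group, advs))
    update_policy(batch)                                      # policy update
    return batch
```

Note that if every answer in a group earns the same reward, all advantages collapse to zero and the group contributes no gradient, which is why reward functions that can discriminate within a group matter.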
## Variations

The implementations differ in the following aspects:

- **Reward Functions**: The implementations use reward functions tailored to the task, with different weights for format and correctness.
  - **Format Reward**: Enforces the XML-style reasoning structure
  - **Correctness Reward**: Validates solution accuracy
  - **Combined Reward**: `R_total = R_format + R_correctness`
- **Reference Model Handling**: Some implementations use a fixed reference model (via a separate server or a frozen copy), while others update the reference model periodically.
- **Training Framework**: The implementations use different training frameworks (e.g., DeepSpeed, pure PyTorch) and optimization techniques (e.g., gradient checkpointing).
- **Batching and Generation**: The approaches to generation (vLLM, Hugging Face Transformers) and batching vary.
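
A minimal sketch of the combined reward, assuming the `<think>...</think>\n<answer>...</answer>` format used above; the regex patterns, equal weights, and exact-match correctness check are illustrative assumptions:

```python
import re

def format_reward(completion):
    """1.0 if the completion follows <think>...</think>\\n<answer>...</answer>."""
    pattern = r"^<think>.*?</think>\n<answer>.*?</answer>$"
    return 1.0 if re.match(pattern, completion, re.DOTALL) else 0.0

def correctness_reward(completion, target):
    """1.0 if the extracted answer matches the target (exact string match here)."""
    m = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return 1.0 if m and m.group(1).strip() == target else 0.0

def total_reward(completion, target):
    # R_total = R_format + R_correctness
    return format_reward(completion) + correctness_reward(completion, target)
```

Keeping the two terms separate lets an implementation reweight them, e.g. to reward well-formed but wrong answers early in training.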
