# GRPO: Group Relative Policy Optimization Algorithm

Group Relative Policy Optimization (GRPO) is an algorithm proposed by DeepSeek for training large language models with reinforcement learning. This repository aggregates and refactors **four distinct implementations** of GRPO, each demonstrating a different approach to the core algorithm while sharing common principles.

## Algorithm

The core GRPO algorithm follows these steps:

1. For each training step, randomly sample $N$ questions $q_1, q_2, \cdots, q_N$.
2. For each question $q_i$, sample $M$ answers $a_{i,1}, a_{i,2}, \cdots, a_{i,M}$.
3. Compute the reward $r_{i,j}$ for each answer $a_{i,j}$.
4. Compute group statistics for each question $q_i$:

$$
\begin{alignedat}{2}
&\mu_i &&\leftarrow \text{mean}(r_{i,1}, r_{i,2}, \cdots, r_{i,M}) \\
&\sigma_i &&\leftarrow \text{std}(r_{i,1}, r_{i,2}, \cdots, r_{i,M})
\end{alignedat}
$$

5. For each token $t$ in answer $a_{i,j}$, compute the advantage:
$$A_{i,j}[t] \leftarrow \frac{r_{i,j} - \mu_i}{\sigma_i}$$

6. Update the policy using the PPO surrogate objective; with a single gradient step per batch, its per-token gradient reduces to the plain policy gradient:
$$\nabla_\theta \log \pi_\theta(a_{i,j}[t]) \cdot A_{i,j}[t]$$

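As a concrete illustration, steps 4 and 5 can be sketched in a few lines of Python. The reward values are toy numbers, and the small `eps` guard against zero variance is an addition not shown in the formula above; real implementations broadcast the per-group advantage over every token of the answer.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-8):
    """Steps 4-5: normalize one group's rewards into advantages."""
    mu = mean(rewards)        # group mean mu_i
    sigma = pstdev(rewards)   # group std sigma_i
    return [(r - mu) / (sigma + eps) for r in rewards]

# One question q_i with M = 4 sampled answers and rule-based rewards r_{i,j}:
rewards = [1.0, 0.0, 1.0, 0.0]
print(group_advantages(rewards))  # correct answers get positive advantage
```

Because the advantages are centered within the group, above-average answers are reinforced and below-average answers are penalized, with no learned value function required.
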
## Implementations

We provide four refactored implementations of GRPO, each with a different focus and design:

### 1. [nanoAhaMoment](src/grpo/nanoAhaMoment)

An implementation from [nanoAhaMoment](https://github.com/nanoAhaMoment/nanoAhaMoment) that separates each step of the GRPO loop into distinct components. It uses a rule-based reward function for a Countdown task and integrates with vLLM for efficient generation.

- Modular pipeline with separated components
- vLLM integration for efficient generation
- DeepSpeed training backend
- Format: `<think>...</think>\n<answer>...</answer>`
- Rule-based reward functions for Countdown tasks

### 2. [GRPO:Zero](src/grpo/GRPO-Zero)

An implementation from [GRPO-Zero](https://github.com/policy-gradient/GRPO-Zero) that uses a separate server for the reference model to offload computation. It uses the Countdown-Tasks-3to4 dataset and a combined reward for correctness and format.

- Qwen2.5-3B-Instruct base model
- Countdown-Tasks-3to4 dataset
- Simplified training workflow
- Reward Function: Combined reward for correctness and format

### 3. [Simple GRPO](src/grpo/Simple_GRPO)

An implementation from [Simple GRPO](https://github.com/lsdefine/simple_GRPO) that uses DeepSpeed for training and a reference-model server. It features a policy-gradient loss with a KL penalty and reward normalization within groups.

- Reference model server architecture
- GSM8K dataset
- KL divergence penalty term
- Per-token advantage calculation
- Distributed training support
- Loss Calculation: `loss = -(policy_ratio * advantage - beta * kl_divergence)`

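The loss formula above can be sketched as follows. This is a scalar, per-token illustration (real code vectorizes over batch and sequence dimensions with PyTorch), and the `exp(ref - cur) - (ref - cur) - 1` KL estimator is an assumption about how `kl_divergence` is computed:

```python
import math

def grpo_token_loss(logp_cur, logp_old, logp_ref, advantage, beta=0.04):
    """loss = -(policy_ratio * advantage - beta * kl_divergence), one token."""
    ratio = math.exp(logp_cur - logp_old)  # pi_theta / pi_old for this token
    # Non-negative estimator of KL(pi_theta || pi_ref) from log-probs
    kl = math.exp(logp_ref - logp_cur) - (logp_ref - logp_cur) - 1
    return -(ratio * advantage - beta * kl)

# When the current policy equals both the old and reference policies,
# ratio = 1 and kl = 0, so the loss is just the negated advantage:
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=0.5))  # -0.5
```

Minimizing this loss pushes up the probability of tokens with positive advantage while the `beta`-weighted KL term keeps the policy near the reference model.
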
### 4. [GRPO from Scratch](src/grpo/GRPO_from_Scratch)

An implementation from ["The LM Book" by Andriy Burkov](https://github.com/aburkov/theLMbook/blob/main/GRPO.py) that demonstrates the core GRPO algorithm step by step. It maintains its own copy of the reference model, refreshed periodically, and performs multiple updates per batch.

- Periodic reference model updates
- Multiple updates per batch (μ-PPO)
- Comprehensive reward decomposition
- Memory optimization techniques
- Reward Function: Combined reward for correctness and format

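Taking several gradient steps on one sampled batch (the μ updates above) is what makes the PPO-style clipped surrogate necessary: once the policy drifts from the sampling policy, the probability ratio must be clipped. A minimal per-token sketch, where `eps = 0.2` is an assumed clip range:

```python
import math

def clipped_surrogate(logp_cur, logp_old, advantage, eps=0.2):
    """PPO clipped objective for one token (to be maximized)."""
    ratio = math.exp(logp_cur - logp_old)
    clipped = max(min(ratio, 1 + eps), 1 - eps)  # clip(ratio, 1-eps, 1+eps)
    return min(ratio * advantage, clipped * advantage)

# A large positive ratio is capped at 1 + eps, limiting each update's size:
print(clipped_surrogate(0.5, 0.0, advantage=1.0))  # 1.2, not e^0.5 ~ 1.65
```

On the first update of a batch the ratio is exactly 1, so clipping is inactive and the objective matches the plain policy gradient from the Algorithm section.
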
## Common Components

All implementations share the following steps:

- **Group Sampling**: For each prompt, multiple completions are generated to form a group.
- **Reward Calculation**: Each completion receives a scalar reward, typically combining correctness and format adherence.
- **Advantage Normalization**: Within each group, rewards are normalized to zero mean and unit variance to form advantages.
- **Policy Update**: The policy is updated with a policy-gradient method (with or without clipping), often including a KL penalty to prevent deviation from a reference policy.

## Variations

The implementations differ in the following aspects:

- Reward Functions: Each implementation uses a reward function tailored to its task, with different weights for format and correctness.
  - **Format Reward**: Enforces XML-style reasoning structure
  - **Correctness Reward**: Validates solution accuracy
  - **Combined Reward**: `R_total = R_format + R_correctness`

- Reference Model Handling: Some implementations use a fixed reference model (via a separate server or a frozen copy), while others update the reference model periodically.

- Training Framework: The implementations use different training frameworks (e.g., DeepSpeed, pure PyTorch) and optimization techniques (e.g., gradient checkpointing).

- Batching and Generation: The approaches to generation (vLLM, Hugging Face transformers) and batching vary.
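
The combined `R_total = R_format + R_correctness` scheme can be sketched as below. The weights and the exact-match correctness check are illustrative assumptions; each implementation defines its own:

```python
import re

# Assumed weights; the implementations weight format vs. correctness differently.
FORMAT_REWARD, CORRECT_REWARD = 0.1, 1.0
ANSWER_RE = re.compile(r"<think>.*?</think>\s*<answer>(.*?)</answer>", re.DOTALL)

def total_reward(completion: str, target: str) -> float:
    """R_total = R_format + R_correctness for one completion."""
    m = ANSWER_RE.search(completion)
    r_format = FORMAT_REWARD if m else 0.0
    r_correct = CORRECT_REWARD if m and m.group(1).strip() == target else 0.0
    return r_format + r_correct

print(total_reward("<think>2 + 2 = 4</think>\n<answer>4</answer>", "4"))  # 1.1
```

Giving a small reward for well-formed output alone provides a learning signal early in training, before the model produces any correct answers.
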