Commit 315364c

Add grpo_zero and lm_book implementations

1 parent 4ae0bb9 commit 315364c

File tree

17 files changed: +2203, -0 lines changed

Lines changed: 89 additions & 0 deletions
# GRPO from Scratch - LM Book by Andriy Burkov

Implementation of the **GRPO (Group Relative Policy Optimization)** training algorithm from:
[https://github.com/aburkov/theLMbook/blob/main/GRPO.py](https://github.com/aburkov/theLMbook/blob/main/GRPO.py).

We refactor the original code into a **modular and easy-to-understand layout**. Each major step of the training process is separated into its own file for clarity and extensibility.

### Steps in the Training Process

1. **Initialization and Setup**:
   - Set random seeds for reproducibility across Python, NumPy, and PyTorch.
   - Load the pre-trained language model (Qwen2.5-0.5B-Instruct) and tokenizer.
   - Configure the model: set the pad token to the EOS token and enable memory optimizations (gradient checkpointing, disabled KV cache, and enabled input gradients).
   - Prepare the GSM8K dataset: format each example with a system prompt and the question, and extract the expected answer.
2. **Initial Evaluation**:
   - Evaluate the model's initial performance on a small subset of the training data (before any fine-tuning).
   - The evaluation function generates a response for each prompt, extracts the predicted answer, and compares it to the expected answer using multiple methods (exact match, single-number extraction, last-number extraction).
3. **GRPO Training Loop**:
   - Training is divided into iterations (outer loop). In each iteration:
     - Create a reference model by making a deep copy of the current policy model and freezing its parameters.
     - Set up an optimizer for the policy model.
     - For a specified number of steps (inner loop):
       - Sample a batch of training examples.
       - Generate multiple completions (rollouts) for each prompt in the batch using the current policy model.
       - For each generated completion, compute the log probabilities under the current policy and the reference model (without gradients).
       - Format the completions for reward computation.
       - Perform multiple GRPO updates (μ times) on the same batch of rollouts:
         - Compute rewards using the combined reward function (correctness and format).
         - Calculate group-relative advantages: within each group of completions for the same prompt, normalize the rewards by subtracting the group mean and dividing by the group standard deviation.
         - Compute the current log probabilities (with gradients) for the generated completions.
         - Calculate the policy ratio (the exponential of the difference between the current and old log probabilities).
         - Compute the surrogate loss (clipped to avoid large updates) and the KL divergence penalty (to prevent the policy from deviating too far from the reference model).
         - Combine the losses and update the policy model's parameters.
4. **Final Evaluation and Saving**:
   - After completing all iterations, evaluate the fine-tuned model on the same evaluation subset.
   - Calculate the improvement in accuracy.
   - Save the fine-tuned model and tokenizer.

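The inner update in step 3 can be sketched for a single token in plain Python. The clip range `eps` and KL weight `beta` below are illustrative defaults, not necessarily the values used in the implementation:

```python
import math

def grpo_token_loss(cur_lp, old_lp, ref_lp, advantage, eps=0.2, beta=0.04):
    # Policy ratio: exponential of the difference between current and old log probabilities.
    ratio = math.exp(cur_lp - old_lp)
    # Clipped surrogate objective: take the minimum of the unclipped and clipped terms.
    surrogate = min(ratio * advantage, max(min(ratio, 1 + eps), 1 - eps) * advantage)
    # KL penalty keeping the policy close to the reference model
    # (the unbiased "k3" estimator commonly used in GRPO-style training).
    kl = math.exp(ref_lp - cur_lp) - (ref_lp - cur_lp) - 1
    # We maximize the objective, so the loss is its negation.
    return -(surrogate - beta * kl)
```

When all three log probabilities coincide, the ratio is 1 and the KL term vanishes, so the loss reduces to minus the advantage; when the ratio drifts outside `[1 - eps, 1 + eps]`, the clipped term caps the update.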
## Code Structure

We refactor the original implementation into a modular, readable, and extensible format. Each component corresponds to a specific phase in the GRPO loop.

```bash
andriy_burkov_lm_book/
├── train_grpo.py
├── data_process.py
├── reward_functions.py
├── evaluation.py
└── completions.py
```

- **`prepare_dataset`**:
  - Loads the GSM8K dataset, formats each example into a prompt string (combining the system message and user question), and extracts the answer.
- **`evaluate_model`**:
  - Evaluates the model by generating a response for each evaluation example.
  - Extracts the predicted answer and compares it to the expected answer using multiple matching strategies (exact string, single number, last number).
  - Prints detailed results and returns the accuracy.
- **`correctness_reward`**:
  - Assigns a reward (0.0, 1.5, or 2.0) based on the correctness of the generated answer compared to the expected answer. Uses exact matching and numeric equivalence.
- **`format_reward`**:
  - Assigns a reward (up to 0.8) for adhering to the required XML format (presence of `<reasoning>`, `</reasoning>`, `<answer>`, and `</answer>` tags).
- **`combined_reward`**:
  - Sums the correctness reward and the format reward for a total reward in the range [0.0, 2.8].
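A minimal sketch of how these rewards could combine; the 0.2-per-tag weighting is an assumption for illustration, not necessarily the exact scheme in `reward_functions.py`:

```python
def format_reward_sketch(text):
    # Hypothetical weighting: 0.2 per required tag, up to 0.8 total.
    tags = ["<reasoning>", "</reasoning>", "<answer>", "</answer>"]
    return 0.2 * sum(1 for tag in tags if tag in text)

def combined_reward_sketch(correctness, text):
    # correctness comes from the correctness reward: 0.0, 1.5, or 2.0.
    return correctness + format_reward_sketch(text)
```

A fully correct, fully formatted answer then scores 2.0 + 0.8 = 2.8, the top of the documented range.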
- **`generate_completions`**:
  - Generates multiple completions for each prompt in a batch using the current model.
  - Returns the tokenized prompts and completions, along with masks that ignore tokens after the first end-of-sequence (EOS) token.
- **`generate_rollout_data`**:
  - Uses `generate_completions` to generate rollouts (completions) for a batch of prompts.
  - Computes log probabilities for these completions under both the current policy and the reference model (without gradients).
  - Returns a dictionary containing the rollout data (inputs, masks, log probabilities, etc.).
- **`compute_group_relative_advantages`**:
  - Groups the rewards by prompt (each group has multiple completions for the same prompt).
  - For each group, normalizes the rewards by subtracting the group mean and dividing by the group standard deviation (adding a small epsilon to avoid division by zero).
  - Returns the normalized advantages for each completion.
- **`maximize_grpo_objective`**:
  - The core function that computes the GRPO loss and updates the model.
  - Computes current log probabilities (with gradients).
  - Computes the policy ratio: the exponential of the difference between the current and old log probabilities (equivalently, the current probability divided by the old probability).
  - Computes rewards and then group-relative advantages.
  - Computes the surrogate loss (the minimum of the unclipped and clipped terms) and the KL penalty.
  - Combines them and performs a gradient update.
- **`train_with_grpo`**:
  - Orchestrates the entire GRPO training process: sets up the reference model and optimizer, and loops over iterations and steps.
  - For each step, it generates rollout data and performs multiple GRPO updates.
- **`optimize_model_memory`**:
  - Applies memory optimization techniques: disables caching, enables gradient checkpointing, and ensures input gradients are required.
- **`main`**:
  - Ties everything together: sets up the model, tokenizer, and dataset, runs the initial evaluation, performs GRPO training, runs the final evaluation, and saves the model.
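The grouping in `compute_group_relative_advantages` can be illustrated with a small pure-Python sketch; whether the implementation uses the population or sample standard deviation, and the exact epsilon, are assumptions here:

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, group_size, eps=1e-4):
    # rewards is a flat list in which each consecutive run of group_size
    # entries holds the rewards for completions of the same prompt.
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu, sigma = mean(group), pstdev(group)
        # Normalize within the group; eps guards against zero variance.
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages
```

For a single group with rewards `[1.0, 3.0]` the advantages come out near `[-1.0, 1.0]`; identical rewards within a group yield zero advantage for every completion.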

src/grpo/andriy_burkov_lm_book/__init__.py

Whitespace-only changes.
Lines changed: 232 additions & 0 deletions

# This code is based on the implementation from: https://github.com/aburkov/theLMbook/blob/main/GRPO.py

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, input_ids):
    """Compute the log probabilities for the tokens specified in input_ids using a selective log-softmax.

    Args:
        logits (torch.Tensor): A tensor of shape (batch_size, seq_len, vocab_size) containing raw logits from the model.
        input_ids (torch.Tensor): A tensor of shape (batch_size, seq_len) containing the token indices for which we want the log probabilities.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, seq_len) where each element is the log probability
        corresponding to the token in input_ids at that position.

    Explanation:
        1. F.log_softmax is applied along the vocabulary dimension (dim=-1) to convert logits into log probabilities.
        2. The tensor input_ids is reshaped (via unsqueeze) to have an extra dimension so that we can use it as indices
           in the log_probs tensor.
        3. torch.gather collects the log probability at the index specified in input_ids for each position.
        4. Finally, squeeze(-1) removes the extra dimension, returning a tensor with the same shape as input_ids.
    """
    # Convert raw logits into log probabilities along the vocabulary axis.
    log_probs = F.log_softmax(logits, dim=-1)  # Shape: (batch_size, seq_len, vocab_size)

    # Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
    # Then, gather the log probability for each token in input_ids.
    selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))

    # Remove the extra last dimension to get back to shape (batch_size, seq_len).
    return selected_log_probs.squeeze(-1)
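The selection logic can be checked on a single row without torch; this plain-Python helper (hypothetical, for illustration only) computes the same quantity for one position:

```python
import math

def log_softmax_at(logits_row, token_id):
    # Log-softmax of one row of logits, evaluated only at the chosen token id.
    m = max(logits_row)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return logits_row[token_id] - log_z
```

A uniform two-token row gives log(1/2) for either token, and exponentiating the values across a row always sums to 1.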


def compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep):
    """Compute per-token log probabilities for a subset of tokens (typically the completion tokens).

    Args:
        model: The language model to use.
        input_ids (torch.Tensor): Tensor of shape (batch_size, total_seq_len) containing token ids
            for both prompt and completion.
        attention_mask (torch.Tensor): Tensor of shape (batch_size, total_seq_len) indicating which tokens are real (1) or padding (0).
        logits_to_keep (int): Number of tokens (from the completion part) for which we need log probabilities.

    Returns:
        torch.Tensor: Log probabilities for the last `logits_to_keep` tokens of each sequence.

    Explanation:
        1. We call the model with logits_to_keep + 1 so that the model outputs one extra logit beyond what is needed.
           This is common in next-token prediction setups.
        2. We slice off the last logit along the sequence dimension because it does not correspond to any target token.
        3. We then restrict both the input_ids and logits to the last logits_to_keep tokens, which
           correspond to the generated completion portion.
        4. Finally, we use selective_log_softmax to compute log probabilities only for those tokens.
    """
    # Run the model forward pass and obtain logits.
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        logits_to_keep=logits_to_keep + 1,  # Request one extra logit for proper alignment.
    ).logits  # Shape: (batch_size, logits_to_keep + 1, vocab_size)

    # Remove the last logit as it does not have a corresponding target token.
    logits = logits[:, :-1, :]

    # Slice the input_ids to keep only the last logits_to_keep tokens.
    # This corresponds to the generated completion tokens.
    input_ids = input_ids[:, -logits_to_keep:]  # Shape: (batch_size, logits_to_keep)

    # Also slice the logits to keep only those corresponding to the completion tokens.
    logits = logits[:, -logits_to_keep:, :]  # Shape: (batch_size, logits_to_keep, vocab_size)

    # Compute and return the log probabilities for the selected tokens.
    return selective_log_softmax(logits, input_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """Create a binary mask for the generated completion tokens so that tokens after the first EOS are ignored.

    Args:
        completion_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) with generated token ids.
        eos_token_id (int): The token id representing the end-of-sequence.

    Returns:
        torch.Tensor: A mask tensor of shape (batch_size, seq_len) with 1s for tokens up to and including the first EOS
        and 0s for tokens following the first EOS.

    Explanation:
        1. First, a boolean mask (is_eos) is created indicating where in the sequence the EOS token appears.
        2. An index tensor (eos_idx) is initialized, assuming that no EOS is found (defaulting to the sequence length).
        3. For sequences where an EOS exists, eos_idx is updated to the position (index) of the first EOS.
        4. A sequence index tensor is created that contains indices for each position in the sequence.
        5. The final mask is computed by comparing the sequence indices to eos_idx (after adding a dimension).
    """
    # Determine which positions in each sequence equal the EOS token.
    is_eos = completion_ids == eos_token_id  # Boolean tensor of shape (batch_size, seq_len)

    # Initialize a tensor to store the index of the first EOS for each sequence.
    # If no EOS is found, default to the full sequence length (is_eos.size(1)).
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)

    # Identify sequences that contain at least one EOS.
    mask_exists = is_eos.any(dim=1)
    # For sequences with an EOS, update eos_idx to the index of the first occurrence.
    eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]

    # Create a tensor of indices [0, 1, 2, ..., seq_len-1] and replicate it for each sequence in the batch.
    sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)

    # Build the mask: positions with an index less than or equal to the first EOS index are marked as 1.
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    return completion_mask
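The masking rule above (keep everything up to and including the first EOS, zero afterwards) is easy to verify with a plain-Python analogue (illustrative only, not part of the implementation):

```python
def first_eos_mask(token_ids, eos_id):
    mask, seen_eos = [], False
    for tok in token_ids:
        # Tokens up to and including the first EOS stay active (1); the rest are masked (0).
        mask.append(0 if seen_eos else 1)
        if tok == eos_id:
            seen_eos = True
    return mask
```

For example, with `eos_id=2`, the sequence `[5, 7, 2, 9, 2]` yields the mask `[1, 1, 1, 0, 0]`: the first EOS stays active and everything after it is ignored.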


def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
    """Generate multiple completions for each prompt and create corresponding attention masks.

    Args:
        model: The language model used for generation.
        tokenizer: The tokenizer to process the prompts and decode the outputs.
        prompts (list of str): List of input prompt strings.
        num_generations (int): Number of completions to generate per prompt.
        max_completion_length (int): Maximum number of new tokens to generate for the completion.

    Returns:
        tuple: Contains the following tensors:
            - prompt_ids: (batch_size * num_generations, prompt_seq_len)
            - prompt_mask: (batch_size * num_generations, prompt_seq_len)
            - completion_ids: (batch_size * num_generations, completion_seq_len)
            - completion_mask: (batch_size * num_generations, completion_seq_len)

    Explanation:
        1. The prompts are tokenized and padded (with padding added to the left).
        2. Each prompt is repeated num_generations times so that multiple completions are generated per prompt.
        3. The model.generate() function is called to generate new tokens.
        4. The generated output contains the prompt followed by the completion; we remove the prompt part to get the completions.
        5. A mask is created (via create_completion_mask) so that only tokens up to the first EOS are considered.
    """
    device = next(model.parameters()).device

    # Tokenize the list of prompts with left padding, so that the prompts are right-aligned
    # and every completion starts at the same position.
    tokenizer.padding_side = "left"
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    prompt_ids = inputs["input_ids"].to(device)  # Shape: (batch_size, prompt_seq_len)
    prompt_mask = inputs["attention_mask"].to(device)  # Shape: (batch_size, prompt_seq_len)
    prompt_length = prompt_ids.size(1)  # Save the prompt length to later separate prompt from completion.

    # Repeat each prompt num_generations times.
    prompt_ids = prompt_ids.repeat_interleave(
        num_generations, dim=0
    )  # New shape: (batch_size*num_generations, prompt_seq_len)
    prompt_mask = prompt_mask.repeat_interleave(
        num_generations, dim=0
    )  # New shape: (batch_size*num_generations, prompt_seq_len)

    # Generate new tokens for each prompt. The output includes the original prompt and the generated tokens.
    outputs = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Remove the prompt portion from the generated output to isolate the completion tokens.
    completion_ids = outputs[:, prompt_length:]  # Shape: (batch_size*num_generations, completion_seq_len)

    # Create a binary mask that ignores tokens beyond the first EOS token.
    completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)

    return prompt_ids, prompt_mask, completion_ids, completion_mask


def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
    """Generate rollouts and compute static log probabilities for both the old policy (current model)
    and the reference model. Gradients are disabled so that these remain fixed.

    Args:
        model: The current model (policy) used to generate rollouts.
        ref_model: The static reference model.
        tokenizer: The tokenizer.
        batch_samples: List of training samples.
        num_generations: Number of completions to generate per prompt.
        max_completion_length: Maximum completion length.

    Returns:
        A dictionary with rollout data including both old and reference log probabilities.
    """
    tokenizer.padding_side = "left"

    # Extract prompts and answers.
    prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
    answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples]

    # Generate completions and associated masks.
    # We generate once, and then use the same completions to compute both sets of log probabilities.
    with torch.no_grad():
        prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
            model, tokenizer, prompts, num_generations, max_completion_length
        )
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)

        # Compute old_log_probs from the current model, with gradients disabled.
        old_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep)

        # Compute ref_log_probs from the reference model, which remains static.
        ref_log_probs = compute_log_probabilities(ref_model, input_ids, attention_mask, logits_to_keep)

    formatted_completions = [[{"content": tokenizer.decode(ids, skip_special_tokens=True)}] for ids in completion_ids]
    repeated_prompts = [p for p in prompts for _ in range(num_generations)]
    repeated_answers = [a for a in answers for _ in range(num_generations)]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "completion_mask": completion_mask,
        "old_log_probs": old_log_probs,  # Static log probs from the current model (old policy)
        "ref_log_probs": ref_log_probs,  # Static log probs from the reference model
        "formatted_completions": formatted_completions,
        "repeated_prompts": repeated_prompts,
        "repeated_answers": repeated_answers,
        "logits_to_keep": logits_to_keep,
        "batch_size": len(prompts),
        "num_generations": num_generations,
    }
