
Commit 791ebdc

XieLipeng0830 authored and ployts committed
feat(cookbooks): add reward model training recipes for Bradley-Terry and SFT
* feat(cookbooks): add reward model training recipes for Bradley-Terry and SFT
* rename: bradley-terry package
* fix: use relative import for BTDataset in trainer.py
* docs: use "judge model" instead of "reward model" in training cookbooks
1 parent ac0a508 commit 791ebdc

File tree

9 files changed: +1470 −0 lines changed
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Bradley-Terry Training

Train judge models using Bradley-Terry loss on preference pairs. This approach learns to rank responses by modeling the probability that one response is preferred over another.

## Overview

Bradley-Terry training is the **simplest and most widely used** method for judge model training. It works with binary preference data (chosen vs. rejected) and optimizes the model to predict which response humans prefer.

> **Tip:** Use Bradley-Terry when you have preference pairs (e.g., from RLHF annotation), binary comparison data, or need a model that outputs scalar scores.

**Training objective:**

The model minimizes the negative log-likelihood of the observed preference:

$$\mathcal{L} = -\log \sigma(r_{\text{chosen}} - r_{\text{rejected}})$$

where $r$ is the scalar score the model assigns to a response and $\sigma$ is the sigmoid function.
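For intuition, here is a minimal PyTorch sketch of this loss and the paired accuracy it is typically monitored with. It assumes only a model that emits one scalar score per sequence; the function name and tensors are illustrative, not verl's API:

```python
import torch
import torch.nn.functional as F

def bt_loss_and_accuracy(scores_chosen: torch.Tensor, scores_rejected: torch.Tensor):
    """Bradley-Terry loss and preference accuracy over a batch of scalar scores."""
    margin = scores_chosen - scores_rejected
    # -log sigmoid(r_chosen - r_rejected), averaged over the batch
    loss = -F.logsigmoid(margin).mean()
    # A pair counts as correct when the chosen response outscores the rejected one
    accuracy = (margin > 0).float().mean()
    return loss, accuracy

# Scores for four preference pairs
loss, acc = bt_loss_and_accuracy(
    torch.tensor([1.2, 0.3, 2.0, -0.5]),
    torch.tensor([0.4, 0.9, 1.1, -1.0]),
)
print(loss.item(), acc.item())  # lower loss / higher accuracy is better
```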
## Quick Start

```bash
# 1. Install dependencies
pip install verl==0.6.1

# 2. Run training
cd cookbooks/training_judge_model/bradley-terry
bash run_bt_rm.sh
```
## Dataset

We provide pre-processed datasets on HuggingFace:

| Dataset | Description | Link |
|---------|-------------|------|
| `agentscope-ai/OpenJudge` | HelpSteer2 preference pairs for BT training | [🔗 HuggingFace](https://huggingface.co/datasets/agentscope-ai/OpenJudge/tree/main/train_rm/bradley_terry) |

**Source:** [nvidia/HelpSteer2](https://huggingface.co/datasets/nvidia/HelpSteer2) preference subset.

**Processing** (conversion sketch below this list):

- Input: HelpSteer2 preference JSONL with a `preference_strength` field (range: -3 to 3)
- Filter: keep pairs with `|preference_strength| >= 1` (clear preference)
- Positive strength → `response_2` is chosen; negative → `response_1` is chosen
- Convert to chat-messages format for Instruct models
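A minimal sketch of that conversion, assuming the HelpSteer2 preference JSONL exposes `prompt`, `response_1`, `response_2`, and `preference_strength` fields (verify the exact column names against the dataset card):

```python
import json
import pandas as pd

def to_messages(prompt: str, response: str) -> str:
    """Serialize a single-turn conversation as a JSON string of chat messages."""
    return json.dumps([
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ])

rows = []
with open("helpsteer2_preference.jsonl") as f:
    for line in f:
        ex = json.loads(line)
        strength = ex["preference_strength"]  # ranges from -3 to 3
        if abs(strength) < 1:
            continue  # drop ties and weak preferences
        # Positive strength prefers response_2; negative prefers response_1
        chosen, rejected = (
            (ex["response_2"], ex["response_1"]) if strength > 0
            else (ex["response_1"], ex["response_2"])
        )
        rows.append({
            "chosen": to_messages(ex["prompt"], chosen),
            "rejected": to_messages(ex["prompt"], rejected),
        })

pd.DataFrame(rows).to_parquet("train.parquet")
```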
## Data Format

Bradley-Terry training expects Parquet files with two columns:

| Column | Type | Description |
|--------|------|-------------|
| `chosen` | string | JSON string of a messages list (preferred response) |
| `rejected` | string | JSON string of a messages list (rejected response) |

**Example data structure:**

```python
import json
import pandas as pd

# Messages format (compatible with tokenizer.apply_chat_template)
chosen = json.dumps([
    {"role": "user", "content": "What are the benefits of exercise?"},
    {"role": "assistant", "content": "Regular exercise improves cardiovascular health, boosts mood, and increases energy levels."}
])
rejected = json.dumps([
    {"role": "user", "content": "What are the benefits of exercise?"},
    {"role": "assistant", "content": "Exercise is good for you."}
])

df = pd.DataFrame({"chosen": [chosen], "rejected": [rejected]})
df.to_parquet("train.parquet")
```

> **Note:** Multi-turn conversations are supported. Include all turns in the messages list.
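To sanity-check a Parquet file before training, you can round-trip one row through the same chat template the dataset loader applies. The tokenizer name below is illustrative; any chat model's tokenizer works:

```python
import json
import pandas as pd
from transformers import AutoTokenizer

df = pd.read_parquet("train.parquet")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")

messages = json.loads(df.iloc[0]["chosen"])  # parse the JSON string back to a list
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(text)  # the exact string the judge model will be trained on
```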
## Configuration

### Training Script (`run_bt_rm.sh`)

Key parameters to customize:

| Parameter | Description | Default |
|-----------|-------------|---------|
| `MODEL_PATH` | Base model for initialization | `qwen3-32b` |
| `TRAIN_FILE` | Training data path | Parquet file |
| `VAL_FILE` | Validation data path | Parquet file |
| `TRAIN_BATCH_SIZE` | Global batch size | 256 |
| `MICRO_BATCH_SIZE` | Per-GPU micro batch size | 1 |
| `LR` | Learning rate | 5e-7 |
| `TOTAL_EPOCHS` | Training epochs | 3 |

### Hydra Config (`trainer.yaml`)

**Data:**

```yaml
data:
  train_batch_size: 256          # Global batch size (distributed across GPUs)
  micro_batch_size_per_gpu: 1    # Per-GPU micro batch size
  max_length: 4096               # Maximum sequence length
  truncation: left               # Truncation: left/right/error
```

**Model:**

```yaml
model:
  partial_pretrain: qwen3-32b            # Base model path
  strategy: fsdp2                        # fsdp or fsdp2
  enable_gradient_checkpointing: true    # Save memory
```

**Optimizer:**

```yaml
optim:
  lr: 5e-7                   # Learning rate
  weight_decay: 0.001        # Weight decay
  warmup_steps_ratio: 0.03   # Warmup steps ratio
  clip_grad: 2.0             # Gradient clipping
  lr_scheduler: cosine       # Scheduler: cosine/wsd/constant
```
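A common point of confusion is how the batch parameters interact: the global batch is sharded across GPUs, and each GPU accumulates gradients over micro batches. A quick sanity check, assuming the standard `global = n_gpus × micro × accumulation` relationship (verl's trainer computes this internally; the helper below is only illustrative):

```python
def grad_accum_steps(train_batch_size: int, n_gpus: int, micro_batch_size_per_gpu: int) -> int:
    """Gradient accumulation steps implied by the batch settings."""
    per_gpu_batch = train_batch_size // n_gpus
    assert per_gpu_batch % micro_batch_size_per_gpu == 0, "batch sizes must divide evenly"
    return per_gpu_batch // micro_batch_size_per_gpu

# With the defaults above on a single 8-GPU node:
print(grad_accum_steps(256, 8, 1))  # 32 micro steps per optimizer step
```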
## Monitoring Training

### Metrics

| Metric | Description |
|--------|-------------|
| `train/loss` | Bradley-Terry loss |
| `train/accuracy` | Preference prediction accuracy (chosen scored above rejected) |
| `train/lr(1e-3)` | Current learning rate (×1e3) |
| `val/loss` | Validation loss |
| `val/accuracy` | Validation accuracy |

### Train/Loss Curve

![BT Training Curve](./bt_train.png)
## Troubleshooting

### OOM (Out of Memory)

- Reduce `micro_batch_size_per_gpu`
- Enable `enable_gradient_checkpointing`
- Reduce `max_length`
- Use the `fsdp2` strategy

### Unstable Training / Loss Explosion

- Lower the learning rate
- Tighten gradient clipping (lower `clip_grad`)
- Check data quality

### Accuracy Not Improving

- Verify data labeling quality
- Check the chosen/rejected mapping
- Increase the learning rate
- Train for more epochs
## Next Steps

- [SFT for Judge Models](../sft/README.md) — Warm-start the judge with supervised fine-tuning
Binary image file added (94.4 KB)
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
"""
Bradley-Terry Dataset for Judge Model Training
- Loads preference data from parquet files
- Each sample contains chosen and rejected responses
- Returns data in format suitable for Bradley-Terry loss
- Uses chat template for Instruct models
"""

import json
from typing import Any, Dict, List, Union

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local


class BTDataset(Dataset):
    """
    Bradley-Terry Dataset for preference learning

    Expected data format in parquet:
    - chosen: messages list [{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
    - rejected: messages list [{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]

    Data is processed with tokenizer.apply_chat_template() for Instruct models.
    """

    def __init__(
        self,
        parquet_files: Union[str, List[str]],
        tokenizer: Union[str, PreTrainedTokenizer],
        config: Dict[str, Any],
    ) -> None:
        self.max_length = config.get("max_length", 4096)
        self.truncation = config.get("truncation", "left")
        self.use_shm = config.get("use_shm", False)

        # Keys for data columns
        self.chosen_key = config.get("chosen_key", "chosen")
        self.rejected_key = config.get("rejected_key", "rejected")

        assert self.truncation in ["error", "left", "right"]

        if not isinstance(parquet_files, list):
            parquet_files = [parquet_files]

        self.parquet_files = parquet_files
        if isinstance(tokenizer, str):
            tokenizer = hf_tokenizer(tokenizer)
        self.tokenizer: PreTrainedTokenizer = tokenizer

        self._download()
        self._read_files_and_process()

    def _download(self) -> None:
        """Download parquet files to local if needed"""
        for i, parquet_file in enumerate(self.parquet_files):
            self.parquet_files[i] = copy_to_local(parquet_file, verbose=True)

    def _read_files_and_process(self) -> None:
        """Read and concatenate all parquet files"""
        dataframes = []
        for parquet_file in self.parquet_files:
            dataframe = pd.read_parquet(parquet_file)
            dataframes.append(dataframe)

        self.dataframe = pd.concat(dataframes, ignore_index=True)

        # Extract chosen and rejected fields (JSON string format, parse to messages list)
        self.chosen_messages = [json.loads(msg) for msg in self.dataframe[self.chosen_key].tolist()]
        self.rejected_messages = [json.loads(msg) for msg in self.dataframe[self.rejected_key].tolist()]

        print(
            f"Loaded {len(self.chosen_messages)} preference pairs from {len(self.parquet_files)} files",
        )

    def __len__(self) -> int:
        return len(self.chosen_messages)

    def _apply_chat_template(self, messages: List[Dict[str, str]]) -> str:
        """
        Apply chat template to convert messages to model-expected format.

        Args:
            messages: List of message dicts [{"role": "user", "content": "..."}, ...]
        """
        formatted = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Remove BOS token if present
        if self.tokenizer.bos_token and formatted.startswith(self.tokenizer.bos_token):
            formatted = formatted[len(self.tokenizer.bos_token) :]
        return formatted

    def _tokenize_messages(self, messages: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
        """Tokenize messages and handle truncation/padding to fixed length"""
        # Apply chat template
        text = self._apply_chat_template(messages)

        # Tokenize
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            return_tensors="pt",
            padding=False,
            truncation=False,
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        sequence_length = input_ids.shape[0]

        # Handle sequence length like SFT dataset
        if sequence_length < self.max_length:
            # Pad sequences
            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            padded_input_ids = (
                torch.ones(
                    size=(self.max_length - sequence_length,),
                    dtype=input_ids.dtype,
                )
                * pad_token_id
            )
            padded_attention_mask = torch.zeros(
                size=(self.max_length - sequence_length,),
                dtype=attention_mask.dtype,
            )

            input_ids = torch.cat((input_ids, padded_input_ids))
            attention_mask = torch.cat((attention_mask, padded_attention_mask))
        elif sequence_length > self.max_length:
            if self.truncation == "left":
                # Keep the end of the conversation (including conclusion)
                input_ids = input_ids[-self.max_length :]
                attention_mask = attention_mask[-self.max_length :]
            elif self.truncation == "right":
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
            elif self.truncation == "error":
                raise ValueError(
                    f"Sequence length {sequence_length} > max_length {self.max_length}",
                )

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """
        Get a preference pair

        Returns:
            dict with keys:
            - input_ids_j: chosen response tokens
            - attention_mask_j: chosen response attention mask
            - input_ids_k: rejected response tokens
            - attention_mask_k: rejected response attention mask
        """
        chosen_messages = self.chosen_messages[item]
        rejected_messages = self.rejected_messages[item]

        # Tokenize both responses
        chosen_tokens = self._tokenize_messages(chosen_messages)
        rejected_tokens = self._tokenize_messages(rejected_messages)

        return {
            "input_ids_j": chosen_tokens["input_ids"],
            "attention_mask_j": chosen_tokens["attention_mask"],
            "input_ids_k": rejected_tokens["input_ids"],
            "attention_mask_k": rejected_tokens["attention_mask"],
        }
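# ---------------------------------------------------------------------------
# Usage sketch (not part of the committed file). How BTDataset might be wired
# into a DataLoader; the tokenizer path and batch size are illustrative.
#
# from torch.utils.data import DataLoader
#
# dataset = BTDataset(
#     parquet_files="train.parquet",
#     tokenizer="Qwen/Qwen3-32B",
#     config={"max_length": 4096, "truncation": "left"},
# )
# loader = DataLoader(dataset, batch_size=2, shuffle=True)
# batch = next(iter(loader))
# # Fixed-length chosen ("_j") and rejected ("_k") tensors, ready for the
# # Bradley-Terry loss: batch["input_ids_j"].shape == (2, 4096)
# ---------------------------------------------------------------------------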
