Skip to content

Commit 702f244

Browse files
committed
feature(nyz): add rlhf dataset
1 parent bf258f8 commit 702f244

File tree

5 files changed

+422
-0
lines changed

5 files changed

+422
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from typing import Iterable, Dict, List, Union, Any, Callable
2+
from functools import partial
3+
from tqdm import tqdm
4+
from torch.utils.data import Dataset
5+
from torch.distributed import get_rank
6+
import torch
7+
import torch.nn.functional as F
8+
9+
10+
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
    """
    Overview:
        Pad every tensor in ``sequences`` along its last dimension up to the
        batch maximum length, then stack them into one batched tensor.
    Arguments:
        - sequences (List[torch.Tensor]): Tensors that may differ in their last-dim length.
        - side (str): Which side to pad, ``"left"`` or ``"right"``.
        - value (int): Fill value used for the padding positions.
    Returns:
        - batch (torch.Tensor): The padded tensors stacked along a new dim 0.
    """
    assert side in ("left", "right")
    target_len = max(seq.size(-1) for seq in sequences)

    def _pad_to_target(seq: torch.Tensor) -> torch.Tensor:
        gap = target_len - seq.size(-1)
        # F.pad's last-dim spec is (left_pad, right_pad).
        amount = (gap, 0) if side == "left" else (0, gap)
        return F.pad(seq, amount, value=value)

    return torch.stack([_pad_to_target(seq) for seq in sequences], dim=0)
19+
20+
21+
class OfflineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
    """

    def __init__(
            self,
            dataset: Iterable[Dict],
            tokenizer,
            max_length: int,
            input_key: str = "input",
            output_key: str = "output",
            label_key: str = "label",
            apply_chat_template: bool = False,
            tokenizer_chat_template: str = None,
            input_template: str = None,
            num_processors: int = 8,
            parallel_load: bool = True
    ) -> None:
        """
        Overview:
            Initialize the OfflineRLDataset.
        Arguments:
            - dataset (Iterable[Dict]): The dataset to preprocess. When ``parallel_load`` is True it
              must provide ``.map``/``.filter``/``.column_names`` (HuggingFace ``datasets.Dataset``).
            - tokenizer: The tokenizer used to preprocess the data.
            - max_length (int): Samples whose prompt reaches ``max_length - 2`` tokens are dropped.
            - input_key (str): The key of the input (prompt) field in each sample.
            - output_key (str): The key of the output (response) field in each sample.
            - label_key (str): The key of the label field in each sample.
            - apply_chat_template (bool): Whether to format samples with the tokenizer's chat template.
            - tokenizer_chat_template (str): Optional chat template string overriding the tokenizer default.
            - input_template (str): Optional ``str.format`` template applied to the raw prompt.
            - num_processors (int): Number of worker processes for parallel preprocessing.
            - parallel_load (bool): Whether to preprocess via ``dataset.map`` instead of a plain loop.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template
            if tokenizer_chat_template:
                self.tokenizer.chat_template = tokenizer_chat_template

        # Bind all preprocessing options once so BOTH loading paths use the same keys/template.
        preprocess_data_fn = partial(
            self._preprocess_data,
            input_template=input_template,
            input_key=input_key,
            output_key=output_key,
            label_key=label_key,
            apply_chat_template=apply_chat_template
        )

        # Parallel loading datasets
        if parallel_load:
            processed_dataset = dataset.map(
                preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
            )
            # preprocess function may return None, so filter out the None
            processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)

            self.prompts = processed_dataset["prompt"]
            self.responses = processed_dataset["response"]
            self.labels = processed_dataset["label"]
            self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
        else:
            self.prompts = []
            self.responses = []
            self.labels = []
            self.prompt_ids_lens = []
            try:
                rank = get_rank()
            except (ValueError, RuntimeError):  # torch.distributed not initialized (e.g. unit test)
                rank = 0
            for data in tqdm(dataset, desc="Preprocessing data", disable=not rank == 0):
                # BUGFIX: forward the configured keys/template via ``preprocess_data_fn``.
                # The previous ``self._preprocess_data(data)`` used the method defaults,
                # silently ignoring input_key/output_key/label_key in the sequential path.
                processed_data = preprocess_data_fn(data)
                if processed_data["prompt"] is not None:
                    self.prompts.append(processed_data["prompt"])
                    self.responses.append(processed_data["response"])
                    self.labels.append(processed_data["label"])
                    self.prompt_ids_lens.append(processed_data["prompt_ids_len"])

    def _preprocess_data(
            self,
            data: Dict[str, Any],
            input_template: str = None,
            input_key: str = "input",
            output_key: str = "output",
            label_key: str = "label",
            apply_chat_template: Union[bool, Callable] = False,
    ) -> Dict[str, Any]:
        """
        Overview:
            Preprocess one raw sample into prompt/response/label fields.
        Arguments:
            - data (Dict[str, Any]): The raw sample.
            - input_template (str): Optional ``str.format`` template applied to the prompt.
            - input_key (str): Key of the prompt field.
            - output_key (str): Key of the response field; when falsy with chat templating,
              the last message of ``data[input_key]`` is treated as the response.
            - label_key (str): Key of the label field.
            - apply_chat_template (Union[bool, Callable]): Falsy, or the chat-template callable
              (usually ``tokenizer.apply_chat_template``).
        Returns:
            - item (Dict[str, Any]): Keys ``prompt`` (None when the prompt is too long and the
              sample should be filtered), ``response``, ``label`` and ``prompt_ids_len``.
        """
        label = data[label_key]

        if apply_chat_template:
            if output_key:
                prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
                # The response is whatever the full conversation adds beyond the prompt text.
                response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
            else:
                prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
        else:
            prompt = data[input_key]
            response = data[output_key]
            if input_template:
                prompt = input_template.format(prompt)

        prompt_token = self.tokenizer(
            prompt,
            max_length=self.max_length,
            # use the batch max length (in `collate_fn`) to pad rather than the global max length
            padding=False,
            truncation=True,
            return_tensors="pt",
            # add special tokens for the prompt in `collate_fn`
            add_special_tokens=False,
        )
        prompt_ids_len = prompt_token["attention_mask"].int().sum().item()

        # filter the sample whose length is greater than max_length (2 for answer length)
        if prompt_ids_len >= self.max_length - 2:
            prompt = None

        return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item.
        Returns:
            - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
        """
        return {
            "prompt": self.prompts[idx],
            "response": self.responses[idx],
            "label": self.labels[idx],
            "prompt_ids_len": self.prompt_ids_lens[idx]
        }

    def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
        """
        Overview:
            Collate items into a right-padded batch. For every item an additional
            unmatched (prompt_i, response_{i+1}) pair with label ``-1`` is appended
            (used to estimate the KL divergence between policy and reference),
            so the returned batch is twice the input size.
        Arguments:
            - item_list (List[Dict[str, Union[torch.Tensor, int]]]): Items from ``__getitem__``.
        Returns:
            - input_ids (torch.Tensor), attention_mask (torch.Tensor),
              labels (torch.LongTensor), prompt_ids_lens (List[int]).
        """

        # Renamed from ``tokenizer`` to avoid shadowing ``self.tokenizer``.
        def _tokenize(prompt: str, response: str):
            text = (prompt + response).rstrip("\n")
            if not text.endswith(self.tokenizer.eos_token):
                text += " " + self.tokenizer.eos_token
            inputs = self.tokenizer(
                text,
                max_length=self.max_length,
                padding=False,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )
            # Force the sequence to end with EOS even after truncation.
            inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
            inputs["attention_mask"][0][-1] = True
            return inputs["input_ids"], inputs["attention_mask"]

        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
        for item in item_list:
            input_ids, attention_mask = _tokenize(item["prompt"], item["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(item["label"])
            prompt_ids_lens.append(item["prompt_ids_len"])

        # add unmatched y'| x (used to estimate the KL divergence between policy and reference)
        for idx in range(len(item_list)):
            next_idx = (idx + 1) % len(item_list)
            input_ids, attention_mask = _tokenize(item_list[idx]["prompt"], item_list[next_idx]["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(-1)
            prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])

        input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
        attention_mask = zero_pad_sequences(tot_masks, side="right")
        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from typing import Any, Dict, Union, Callable, Iterable
2+
from tqdm import tqdm
3+
from torch.utils.data import Dataset
4+
from torch.distributed import get_rank
5+
6+
7+
class OnlineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for OnlineRL LLM training like PPO.
    """

    def __init__(
            self,
            dataset: Iterable[Dict],
            tokenizer,
            input_key: str = "input",
            apply_chat_template: bool = False,
            input_template: str = None,
    ) -> None:
        """
        Overview:
            Initialize the OnlineRLDataset by preprocessing every sample into a prompt string.
        Arguments:
            - dataset (Iterable[Dict]): The raw samples to preprocess.
            - tokenizer: The tokenizer whose chat template may be applied to the data.
            - input_key (str): The key of the input data in each sample.
            - apply_chat_template (bool): Whether to apply the tokenizer's chat template.
            - input_template (str): Optional ``str.format`` template applied to each prompt.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.input_template = input_template

        # Resolve the boolean flag into the tokenizer's bound template callable (or keep it falsy).
        template_fn = self.tokenizer.apply_chat_template if apply_chat_template else apply_chat_template

        try:
            rank = get_rank()
        except ValueError:  # not initialized yet, which is the case in unit test
            rank = 0
        # Only rank 0 shows the progress bar so multi-process runs don't duplicate output.
        samples = tqdm(dataset, desc="Preprocessing data", disable=rank != 0)
        self.prompts = [self._preprocess_data(sample, input_template, input_key, template_fn) for sample in samples]

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The number of preprocessed prompts.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> str:
        """
        Overview:
            Get the prompt at the given index.
        Arguments:
            - idx (int): The index of the prompt to fetch.
        Returns:
            - item (str): The formatted prompt at the given index.
        """
        return self.prompts[idx]

    def _preprocess_data(
            self,
            data: Dict[str, Any],
            input_template: str = None,
            input_key: str = "input",
            apply_chat_template: Union[bool, Callable] = False,
    ) -> str:
        """
        Overview:
            Preprocess one sample into a formatted prompt string.
        Arguments:
            - data (Dict[str, Any]): The raw sample.
            - input_template (str): Optional ``str.format`` template applied to the prompt.
            - input_key (str): The key of the input data in the sample.
            - apply_chat_template (Union[bool, Callable]): Falsy, or the chat-template callable, \
                usually ``tokenizer.apply_chat_template``.
        Returns:
            - prompt (str): The formatted prompt.
        """
        if not apply_chat_template:
            prompt = data[input_key]
            if input_template:
                prompt = input_template.format(prompt)
            return prompt

        chat = data[input_key]
        if isinstance(chat, str):
            # Promote a bare string to a single-turn user conversation.
            chat = [{"role": "user", "content": chat}]
        assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
        return apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
from datasets import load_dataset, concatenate_datasets
3+
from rl.data.offlinerl_dataset import OfflineRLDataset
4+
from transformers import AutoTokenizer
5+
6+
7+
@pytest.fixture
def dataset():
    """Build a flat preference dataset: each pairwise VL-RewardBench sample is
    split into its two responses, then the halves are concatenated and shuffled."""
    source = load_dataset("MMInstruction/VL-RewardBench", split='test')

    def flatten(response_idx):
        # Keep the query/label, select one response of the pair.
        return source.map(
            lambda x: {
                "prompt": x["query"],
                "response": x["response"][response_idx],
                'human_ranking': x["human_ranking"][0]
            }
        )

    combined = concatenate_datasets([flatten(0), flatten(1)])
    return combined.shuffle(seed=42)
31+
32+
33+
@pytest.fixture
def tokenizer():
    """Provide the tokenizer shared by all tests in this module."""
    pretrained_name = "Qwen/Qwen2.5-Coder-7B"
    return AutoTokenizer.from_pretrained(pretrained_name)
37+
38+
39+
@pytest.mark.unittest
def test_offline_rl_dataset_initialization(dataset, tokenizer):
    """Constructing the dataset keeps all samples under a generous max_length
    and filters some out under a tight one."""
    common_kwargs = dict(
        dataset=dataset,
        tokenizer=tokenizer,
        input_key="query",
        output_key="response",
        label_key="human_ranking",
    )
    # Large max_length: nothing gets filtered.
    offline_dataset = OfflineRLDataset(max_length=1024, **common_kwargs)
    assert len(offline_dataset) == len(dataset)
    # lower max_length will filter out some samples
    offline_dataset = OfflineRLDataset(max_length=256, **common_kwargs)
    assert len(offline_dataset) < len(dataset)
61+
62+
63+
@pytest.mark.unittest
def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
    """A retrieved item exposes all four expected fields."""
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=256,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    sample = offline_dataset[0]
    for field in ("prompt", "response", "label", "prompt_ids_len"):
        assert field in sample
    print(sample)
80+
81+
82+
@pytest.mark.unittest
def test_offline_rl_dataset_collate_fn(dataset, tokenizer):
    """collate_fn doubles the batch by appending unmatched (prompt, response') pairs."""
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=256,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    batch_size = 10
    batch = [offline_dataset[i] for i in range(batch_size)]
    input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(batch)
    expected = len(batch) * 2  # because of the unmatched y'| x
    assert input_ids.size(0) == expected
    assert attention_mask.size(0) == expected
    assert labels.size(0) == expected
    assert len(prompt_ids_lens) == expected

0 commit comments

Comments
 (0)