Skip to content

Commit 729e21f

Browse files
authored
Merge pull request #5 from explodinggradients/dev
Fix github workflows
2 parents 97975b9 + 7465fcc commit 729e21f

File tree

10 files changed

+135
-62
lines changed

10 files changed

+135
-62
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ jobs:
3838
sudo apt-get install libsndfile1
3939
pip install -r requirements.txt
4040
pip install black
41-
- name: Install reward-model
42-
run: |
43-
pip install -e .[dev]
41+
# - name: Install reward-model
42+
# run: |
43+
# pip install -e .
4444
- name: Run black
45-
# run:
46-
# black --check . --exclude blade2blade/__init__.py
45+
run:
46+
black --check .

.pre-commit-config.yaml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
repos:
2+
- repo: https://github.com/ambv/black
3+
rev: 22.8.0
4+
hooks:
5+
- id: black
6+
7+
# Sort imports
8+
- repo: https://github.com/PyCQA/isort
9+
rev: 5.10.1
10+
hooks:
11+
- id: isort
12+
args: ["--profile", "black"]
13+
14+
- repo: https://gitlab.com/pycqa/flake8
15+
rev: 5.0.4
16+
hooks:
17+
- id: flake8
18+
args: ['--ignore=E203,E501,F811,E712,W503']
19+
exclude: __init__.py
20+
21+
# Formatting, Whitespace, etc
22+
- repo: https://github.com/pre-commit/pre-commit-hooks
23+
rev: v3.2.0
24+
hooks:
25+
- id: trailing-whitespace
26+
- id: check-added-large-files
27+
args: ['--maxkb=1000']
28+
- id: check-ast
29+
- id: check-json
30+
- id: check-merge-conflict
31+
- id: check-xml
32+
- id: check-yaml
33+
- id: debug-statements
34+
- id: end-of-file-fixer
35+
- id: requirements-txt-fixer
36+
- id: mixed-line-ending
37+
args: ['--fix=no']

README.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,19 @@
11
# Reward-Model
2-
Framework for reward model for RLHF.
2+
Framework for reward model for RLHF.
3+
4+
5+
### Quick Start
6+
* Inference
7+
```python
8+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
9+
MODEL = ""
10+
11+
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
12+
tokenizer = AutoTokenizer.from_pretrained(MODEL)
13+
14+
```
15+
16+
* Training
17+
```bash
18+
python src/training.py --config-name <your-config-name>
19+
```

pyproject.toml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[project]
2-
name = "blade2blade"
3-
description = "Adversarial Training and SFT for Bot Safety Models"
2+
name = "Reward-Model"
3+
description = "Reward Model training for LLM alignment"
44
version = "0.0.1"
55
authors = [
6-
{ name = "LAION-AI", email = "[email protected]" }
6+
{ name = "Exploding gradients", email = "[email protected]" }
77
]
88
readme = "README.md"
99
dependencies = [
@@ -12,11 +12,8 @@ dependencies = [
1212
"transformers>=4.27.4",
1313
"hydra-core>=1.3.2",
1414
"tokenizers>=0.13.2",
15-
]
16-
17-
[project.optional-dependencies]
18-
dev = [
1915
"wandb>=0.14.0",
16+
2017
]
2118

2219
[build-system]

src/config/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ datasets:
1818
- webgpt:
1919
split: "train"
2020

21-
validation_size: 0.15
21+
validation_size: 0.15

src/datacollator.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from transformers import PreTrainedTokenizer
2-
import torch
31
from dataclasses import dataclass
42

3+
import torch
4+
from transformers import PreTrainedTokenizer
5+
56
from utils import SPECIAL_TOKENS
67

78

@@ -11,10 +12,18 @@ class RMDataCollator:
1112
max_length: int = 512
1213

1314
def format_prefix(self, prompts, eos):
14-
15-
prompts = ["{}{}{}".format(SPECIAL_TOKENS["prompter"] if i%2==0 else SPECIAL_TOKENS["assistant"], prompt, eos) for i,prompt in enumerate(prompts)]
15+
prompts = [
16+
"{}{}{}".format(
17+
SPECIAL_TOKENS["prompter"]
18+
if i % 2 == 0
19+
else SPECIAL_TOKENS["assistant"],
20+
prompt,
21+
eos,
22+
)
23+
for i, prompt in enumerate(prompts)
24+
]
1625
return "".join(prompts)
17-
26+
1827
def format_suffix(self, answer, eos):
1928
return "{}{}{}".format(SPECIAL_TOKENS["assistant"], answer, eos)
2029

@@ -24,7 +33,7 @@ def process_example(self, example):
2433
prefix, outputs = example
2534
prefix = self.format_prefix(prefix, eos)
2635
outputs = [self.format_suffix(output, eos) for output in outputs]
27-
print(prefix,outputs)
36+
print(prefix, outputs)
2837
prefix_tokens = self.tokenizer.encode(prefix)
2938
input_ids, attention_masks = [], []
3039
for output in outputs:

src/dataset.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import re
12
from collections import defaultdict
2-
from torch.utils.data import Dataset
3+
from typing import List, Union
4+
35
from datasets import load_dataset
4-
import re
5-
from typing import Union, List
66
from omegaconf import OmegaConf
7+
from torch.utils.data import Dataset
78

89

910
class HFSummary(Dataset):
@@ -16,7 +17,7 @@ def __init__(self, split: Union[List[str], str] = "train"):
1617
if isinstance(split, OmegaConf):
1718
self.split = OmegaConf.to_object(split)
1819
else:
19-
self.split = split
20+
self.split = split
2021
dataset = load_dataset(self.name, "axis", split=self.split)
2122
self.data_dict = self.prepare_axis(dataset)
2223
self.postids = list(self.data_dict.keys())
@@ -96,30 +97,36 @@ def __getitem__(self, idx):
9697
class AnthropicRLFH(Dataset):
9798
name = "Dahoas/full-hh-rlhf"
9899

99-
def __init__(self,split:Union[List[str],str]="train"):
100+
def __init__(self, split: Union[List[str], str] = "train"):
100101
super().__init__()
101102
if isinstance(split, str):
102103
split = [split]
103104
if isinstance(split, OmegaConf):
104105
self.split = OmegaConf.to_object(split)
105106
else:
106107
self.split = split
107-
dataset = load_dataset(self.name,split=self.split)
108+
dataset = load_dataset(self.name, split=self.split)
108109
self.data_dict = defaultdict(dict)
109110
id = 0
110111
for data in dataset:
111112
for item in data:
112-
dialogs = [text.replace("\n\n","").strip() for text in re.split(r'Human:|Assistant:',item["prompt"])]
113-
dialogs = [text for text in dialogs if text!=""]
114-
self.data_dict[f"prompt{id}"].update({"prompt":dialogs,
115-
"answers":[item["chosen"], item["rejected"]]})
116-
id+=1
113+
dialogs = [
114+
text.replace("\n\n", "").strip()
115+
for text in re.split(r"Human:|Assistant:", item["prompt"])
116+
]
117+
dialogs = [text for text in dialogs if text != ""]
118+
self.data_dict[f"prompt{id}"].update(
119+
{"prompt": dialogs, "answers": [item["chosen"], item["rejected"]]}
120+
)
121+
id += 1
117122

118123
self.prompt_ids = list(self.data_dict.keys())
119124

120-
def __len__(self,):
125+
def __len__(
126+
self,
127+
):
121128
return len(self.prompt_ids)
122-
129+
123130
def __getitem__(self, idx):
124-
prompt, answers = self.data_dict.get(self.prompt_ids[idx],{}).values()
125-
return prompt, answers
131+
prompt, answers = self.data_dict.get(self.prompt_ids[idx], {}).values()
132+
return prompt, answers

src/model.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from dataclasses import dataclass
2-
from turtle import hideturtle
2+
3+
import torch
4+
from torch import nn
35
from transformers import (
6+
AutoConfig,
7+
AutoModelForSequenceClassification,
48
GPTNeoXConfig,
5-
GPTNeoXPreTrainedModel,
69
GPTNeoXModel,
7-
AutoModelForSequenceClassification,
8-
AutoConfig,
10+
GPTNeoXPreTrainedModel,
911
)
10-
from torch import nn
11-
import torch
1212
from transformers.utils import ModelOutput
1313

1414

@@ -20,11 +20,13 @@ class GPTNeoxRMOuptput(ModelOutput):
2020

2121
logits: torch.FloatTensor = None
2222

23+
2324
class GPTNeoXConfigRM(GPTNeoXConfig):
2425
model_type = "rm_gptneox_config"
26+
2527
def __init__(
2628
self,
27-
pooling = "last",
29+
pooling="last",
2830
**kwargs,
2931
):
3032
super().__init__(**kwargs)
@@ -33,7 +35,7 @@ def __init__(
3335

3436
class GPTNeoXRM(GPTNeoXPreTrainedModel):
3537
config_class = GPTNeoXConfigRM
36-
"""
38+
"""
3739
Reward Model
3840
"""
3941

@@ -44,7 +46,9 @@ def __init__(
4446
super().__init__(config)
4547
self.gpt_neox = GPTNeoXModel(config)
4648
self.pooling = config.pooling
47-
hidden_size = config.hidden_size if self.pooling != "mean-max" else config.hidden_size * 2
49+
hidden_size = (
50+
config.hidden_size if self.pooling != "mean-max" else config.hidden_size * 2
51+
)
4852
self.out_layer = nn.Linear(hidden_size, 1)
4953

5054
def forward(
@@ -74,22 +78,20 @@ def forward(
7478
) / attention_mask.sum(dim=1).unsqueeze(-1)
7579
elif self.pooling == "last":
7680
if attention_mask is None:
77-
hidden_states = hidden_states[:,-1,:]
81+
hidden_states = hidden_states[:, -1, :]
7882
else:
7983
last_idx = attention_mask.cumsum(1).argmax(1)
80-
last_idx = last_idx.view(-1,1,1).expand(-1,1,hidden_states.size(-1))
81-
hidden_states = torch.gather(hidden_states,1,last_idx).squeeze(1)
84+
last_idx = last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
85+
hidden_states = torch.gather(hidden_states, 1, last_idx).squeeze(1)
8286
elif self.pooling == "mean-max":
8387
if attention_mask is None:
8488
mean, max = hidden_states.mean(dim=1), hidden_states.max(dim=1).values
85-
hidden_states = torch.cat([mean,max],1)
89+
hidden_states = torch.cat([mean, max], 1)
8690
else:
8791
mean = (hidden_states * attention_mask.unsqueeze(-1)).sum(
8892
dim=1
8993
) / attention_mask.sum(dim=1).unsqueeze(-1)
90-
max = (hidden_states * attention_mask.unsqueeze(-1)).max(
91-
dim=1
92-
).values
94+
max = (hidden_states * attention_mask.unsqueeze(-1)).max(dim=1).values
9395
hidden_states = torch.cat([mean, max], 1)
9496
else:
9597
raise ValueError(f"invalid pooling {self.pooling}")
@@ -103,4 +105,4 @@ def forward(
103105

104106

105107
AutoConfig.register("rm_gptneox_config", GPTNeoXConfigRM)
106-
AutoModelForSequenceClassification.register(GPTNeoXConfigRM, GPTNeoXRM)
108+
AutoModelForSequenceClassification.register(GPTNeoXConfigRM, GPTNeoXRM)

src/trainer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import Any, Dict, List, Optional, Tuple, Union
2-
import torch
3-
from torch import nn
4-
from transformers import Trainer
1+
import os
2+
from typing import Any, Dict, List, Optional, Union
3+
54
import hydra
5+
import torch
66
from hydra.utils import instantiate
77
from omegaconf import DictConfig
8-
from datacollator import RMDataCollator
9-
import os
8+
from torch import nn
9+
from transformers import Trainer
1010

11+
from datacollator import RMDataCollator
1112
from loss import RMLoss
1213
from model import GPTNeoXRM
1314
from utils import get_tokenizer, prepare_datasets

src/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from transformers import AutoTokenizer
1+
import torch
2+
from tokenizers import pre_tokenizers
23
from torch.utils.data import ConcatDataset, random_split
4+
from transformers import AutoTokenizer
5+
36
from dataset import AnthropicRLFH, HFSummary, WebGPT
4-
import torch
57

68
SPECIAL_TOKENS = {"prompter": "|prompter|", "assistant": "|assistant|"}
79
generator = torch.Generator().manual_seed(42)
@@ -50,7 +52,8 @@ def prepare_datasets(config):
5052

5153
dataset = ConcatDataset(dataset_list)
5254
train_dataset, valid_dataset = random_split(
53-
dataset, [1 - config.validation_size, config.validation_size],
54-
generator=generator
55+
dataset,
56+
[1 - config.validation_size, config.validation_size],
57+
generator=generator,
5558
)
5659
return train_dataset, valid_dataset

0 commit comments

Comments
 (0)