
Commit 463d729

Add data loader for HF oasst1 (#2951)

grgau, CloseChoice, and andreaskoepf authored
Make it possible to work with the OASST1 dataset directly from the HuggingFace
hub. Add a new `hf_dataset_name` parameter to the `load_oasst_export` function.

---------

Co-authored-by: grgau <[email protected]>
Co-authored-by: Tobias Pitters <[email protected]>
Co-authored-by: Andreas Köpf <[email protected]>
1 parent fff7272 commit 463d729
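
In practice the change boils down to one new keyword argument. A minimal, hedged usage sketch (the module path follows from the import statements in the diff below; the sample values and the explicit `mode="sft"` are illustrative):

```python
from model_training.custom_datasets.oasst_dataset import load_oasst_export

# Load OASST message trees straight from the HuggingFace hub instead of a
# local JSONL export; returns flattened train/validation splits.
train, val = load_oasst_export(
    hf_dataset_name="OpenAssistant/oasst1",  # the new parameter
    lang="en,es,de,fr",
    val_split=0.05,
    mode="sft",
)
print(len(train), len(val))
```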

File tree (7 files changed, +179 -37 lines changed)

- model/README.md
- model/model_training/configs/config.yaml
- model/model_training/configs/config_rm.yaml
- model/model_training/custom_datasets/oasst_dataset.py
- oasst-data/oasst_data/__init__.py
- oasst-data/oasst_data/reader.py
- oasst-data/pyproject.toml

model/README.md (39 additions & 5 deletions)

````diff
@@ -16,15 +16,49 @@ export DATA_PATH=$PWD/.cache
 export MODEL_PATH=$PWD/.saved_models
 ```
 
-2. Then download the OA data.
+2. Then download the OA message tree JSONL file or declare the HuggingFace
+   dataset to use.
+
+   Create a new or modify an existing configuration section in the `config.yaml`
+   (SFT), `config_rm.yaml` (RM) or `config_rl.yaml` (RL) YAML configuration files
+   located in the `model_training/configs/` directory and specify the OA JSONL data
+   file or HuggingFace dataset to use.
+
+   - To use a local OASST JSONL file (either `.jsonl` or `.jsonl.gz`) specify the
+     file name with the `input_file_path` configuration option. Place the file
+     either in the `cache_dir` (`DATA_PATH`) or specify an absolute path.
 
    ```bash
-   cp /path/to/<oa.jsonl> $DATA_PATH
+   cp /path/to/<oasst.trees.jsonl> $DATA_PATH
+   ```
+
+   Example:
+
+   ```yaml
+   my_data_config:
+     datasets:
+       - oasst_export:
+           input_file_path: oasst_export.trees.jsonl.gz
   ```
 
-Change the `<oa.jsonl>` file used in the `model_training/configs/config.yaml`,
-`model_training/configs/config_rl.yaml` and `reward/instructor/rank_datasets.py`
-files.
+   - To use a HuggingFace dataset specify the dataset name with the
+     `hf_dataset_name` configuration option.
+
+   Example:
+
+   ```yaml
+   my_data_config:
+     datasets:
+       - oasst_export:
+           hf_dataset_name: OpenAssistant/oasst1
+   ```
+
+   _Note_: If both `hf_dataset_name` and `input_file_path` are specified
+   `input_file_path` will take precedence.
+
+   See the
+   [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
+   dataset card on the HuggingFace hub for more information.
 
 - (TODO) add better parsing of the config files that is consistent for sft, rm
   and rl training.
````
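
The precedence rule in the note above follows directly from the loader change further down (`input_file_path` is checked first). A hedged sketch; the file name is a placeholder:

```python
from model_training.custom_datasets.oasst_dataset import load_oasst_export

# When both sources are configured the local file wins; a relative path is
# resolved against data_path (the cache_dir / DATA_PATH mentioned above).
train, val = load_oasst_export(
    input_file_path="oasst_export.trees.jsonl.gz",  # placeholder file name
    hf_dataset_name="OpenAssistant/oasst1",  # ignored when a file is given
    data_path=".cache",
    mode="sft",
)
```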

model/model_training/configs/config.yaml (22 additions & 4 deletions)

```diff
@@ -186,7 +186,9 @@ oasst_only:
   datasets:
     - oasst_export:
         lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-04-04_oasst_ready.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
+        #input_file_path: 2023-04-12_oasst_ready.trees.jsonl.gz
+        #top_k: 1
         val_split: 0.05
   sort_by_length: false
   use_custom_sampler: false
@@ -206,14 +208,28 @@ oasst_export_eu:
   datasets:
     - oasst_export:
         lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
+    - gpt4all
+    - alpaca
+    - code_alpaca
+    - oig_file:
+        source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl
+        max_count: 10000
+        min_length: 100
+        val_split: 0.1
+    - oig_file:
+        source_url: https://huggingface.co/datasets/laion/OIG/raw/main/unified_grade_school_math_instructions.jsonl
+        val_split: 0.1
+        min_length: 100
+  sort_by_length: false
+  use_custom_sampler: false
 
 oasst_export_latin_cyrillic:
   save_strategy: epoch
   datasets:
     - oasst_export:
         lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
     - alpaca
     - oig_file:
         source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl
@@ -364,7 +380,7 @@ llama-30b-sft-6:
   datasets:
     - oasst_export:
         lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-04-12_oasst_release_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
         val_split: 0.05
     - vicuna:
         val_split: 0.05
@@ -712,6 +728,7 @@ galactica-125m:
   gradient_accumulation_steps: 2
   per_device_train_batch_size: 4
   per_device_eval_batch_size: 4
+  dtype: fp32
 
 gpt-jt:
   learning_rate: 8e-6
@@ -761,3 +778,4 @@ debug:
   log_wandb: false
   verbose: true
   num_train_epochs: 0.2
+  dtype: fp32
```

model/model_training/configs/config_rm.yaml (2 additions & 4 deletions)

```diff
@@ -49,7 +49,6 @@ oasst-rm-1-pythia-6.9b:
   pooling: last
   datasets:
     - augment_oasst:
-        #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
         input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl
     - anthropic_rlhf:
         fraction: 0.1
@@ -98,10 +97,9 @@ oasst-rm-1-pythia-2.8b:
   datasets:
     - oasst_export:
         lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
         val_split: 0.1
     - augment_oasst:
-        #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
         input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl
     - anthropic_rlhf:
         fraction: 0.1
@@ -142,7 +140,7 @@ oasst-rm-1-pythia-1.4b:
   datasets:
     - oasst_export:
         lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
         val_split: 0.1
     - augment_oasst:
         input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
```

model/model_training/custom_datasets/oasst_dataset.py (24 additions & 12 deletions)

```diff
@@ -1,8 +1,9 @@
 from pathlib import Path
-from typing import Literal, Optional
+from typing import Iterable, Literal, Optional
 
 from model_training.custom_datasets.formatting import DatasetEntrySft, Role, Utterance
-from oasst_data import ExportMessageNode, read_message_trees, visit_threads_depth_first
+from oasst_data import ExportMessageNode, read_dataset_message_trees, read_message_trees, visit_threads_depth_first
+from oasst_data.schemas import ExportMessageTree
 from torch import Generator
 from torch.utils.data import Dataset, random_split
 
@@ -20,7 +21,8 @@ def __getitem__(self, index):
 
 
 def load_oasst_export(
-    input_file_path: str | Path,
+    input_file_path: Optional[str | Path] = None,
+    hf_dataset_name: Optional[str] = "OpenAssistant/oasst1",
     val_split: float = 0.2,
     lang: str = "en",
     top_k: Optional[int] = None,
@@ -31,20 +33,27 @@ def load_oasst_export(
     if mode not in ("sft", "rm", "rl"):
         raise ValueError(f"Unknown dataset mode: {mode}")
 
-    lang_codes = lang.split(",")
+    lang_codes: list[str] = lang.split(",")
 
     generator = Generator()
     generator.manual_seed(manual_seed)
 
-    if not isinstance(input_file_path, Path):
-        input_file_path = Path(input_file_path)
-    if not input_file_path.is_absolute() and data_path:
-        if not isinstance(data_path, Path):
-            data_path = Path(data_path)
-        input_file_path = data_path / input_file_path
+    tree_iter: Iterable[ExportMessageTree] = None
+    if input_file_path:
+        if not isinstance(input_file_path, Path):
+            input_file_path = Path(input_file_path)
+        if not input_file_path.is_absolute() and data_path:
+            if not isinstance(data_path, Path):
+                data_path = Path(data_path)
+            input_file_path = data_path / input_file_path
+        tree_iter = read_message_trees(input_file_path)
+    elif hf_dataset_name:
+        tree_iter = read_dataset_message_trees(hf_dataset_name, split="train+validation")
+    else:
+        raise RuntimeError("Either `input_file_path` or `hf_dataset_name` must be specified.")
 
     threads_per_tree = []
-    for tree in read_message_trees(input_file_path):
+    for tree in tree_iter:
         if tree.tree_state != "ready_for_export" or not tree.prompt.review_result or tree.prompt.lang not in lang_codes:
             continue
 
@@ -145,6 +154,9 @@ def flatten(ds: ListDataset) -> ListDataset:
     train = flatten(splits[0])
     val = flatten(splits[1])
 
-    print(f"OASST data {str(input_file_path)}: {len(train)=}, {len(val)=}")
+    if input_file_path:
+        print(f"OASST JSONL file {str(input_file_path)}: {len(train)=}, {len(val)=}")
+    else:
+        print(f"OASST HF dataset {hf_dataset_name}: {len(train)=}, {len(val)=}")
 
     return train, val
```
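
Two behaviours of the new source-selection logic above, sketched under the assumption that the remaining keyword arguments keep their defaults:

```python
from model_training.custom_datasets.oasst_dataset import load_oasst_export

# With no input_file_path the hf_dataset_name default ("OpenAssistant/oasst1")
# applies, so this loads from the hub via read_dataset_message_trees().
train, val = load_oasst_export(lang="en", val_split=0.05, mode="sft")

# Explicitly disabling both sources raises the new RuntimeError.
try:
    load_oasst_export(input_file_path=None, hf_dataset_name=None, mode="sft")
except RuntimeError as err:
    print(err)  # Either `input_file_path` or `hf_dataset_name` must be specified.
```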

oasst-data/oasst_data/__init__.py (10 additions & 1 deletion)

```diff
@@ -1,4 +1,11 @@
-from oasst_data.reader import read_message_list, read_message_tree_list, read_message_trees, read_messages
+from oasst_data.reader import (
+    read_dataset_message_trees,
+    read_dataset_messages,
+    read_message_list,
+    read_message_tree_list,
+    read_message_trees,
+    read_messages,
+)
 from oasst_data.schemas import (
     ExportMessageEvent,
     ExportMessageEventEmoji,
@@ -33,4 +40,6 @@
     "visit_messages_depth_first",
     "write_message_trees",
     "write_messages",
+    "read_dataset_message_trees",
+    "read_dataset_messages",
 ]
```

oasst-data/oasst_data/reader.py (80 additions & 10 deletions)

```diff
@@ -4,6 +4,7 @@
 from typing import Callable, Iterable, Optional, TextIO
 
 import pydantic
+from datasets import load_dataset
 
 from .schemas import ExportMessageNode, ExportMessageTree
 
@@ -17,22 +18,24 @@ def open_jsonl_read(input_file_path: str | Path) -> TextIO:
     return input_file_path.open("r", encoding="UTF-8")
 
 
-def read_oasst_obj(line: str) -> ExportMessageTree | ExportMessageNode:
-    dict_tree = json.loads(line)
+def read_oasst_obj(obj_dict: dict) -> ExportMessageTree | ExportMessageNode:
     # validate data
-    if "message_id" in dict_tree:
-        return pydantic.parse_obj_as(ExportMessageNode, dict_tree)
-    elif "message_tree_id" in dict_tree:
-        return pydantic.parse_obj_as(ExportMessageTree, dict_tree)
+    if "message_id" in obj_dict:
+        return pydantic.parse_obj_as(ExportMessageNode, obj_dict)
+    elif "message_tree_id" in obj_dict:
+        return pydantic.parse_obj_as(ExportMessageTree, obj_dict)
 
     raise RuntimeError("Unknown object in jsonl file")
 
 
-def read_oasst_jsonl(input_file_path: str | Path) -> Iterable[ExportMessageTree | ExportMessageNode]:
+def read_oasst_jsonl(
+    input_file_path: str | Path,
+) -> Iterable[ExportMessageTree | ExportMessageNode]:
     with open_jsonl_read(input_file_path) as file_in:
         # read one object per line
         for line in file_in:
-            yield read_oasst_obj(line)
+            dict_tree = json.loads(line)
+            yield read_oasst_obj(dict_tree)
 
 
 def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTree]:
@@ -42,18 +45,85 @@ def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTree]:
 
 
 def read_message_tree_list(
-    input_file_path: str | Path, filter: Optional[Callable[[ExportMessageTree], bool]] = None
+    input_file_path: str | Path,
+    filter: Optional[Callable[[ExportMessageTree], bool]] = None,
 ) -> list[ExportMessageTree]:
     return [t for t in read_message_trees(input_file_path) if not filter or filter(t)]
 
 
+def convert_hf_message(row: dict) -> None:
+    emojis = row.get("emojis")
+    if emojis:
+        row["emojis"] = dict(zip(emojis["name"], emojis["count"]))
+    labels = row.get("labels")
+    if labels:
+        row["labels"] = {
+            name: {"value": value, "count": count}
+            for name, value, count in zip(labels["name"], labels["value"], labels["count"])
+        }
+
+
 def read_messages(input_file_path: str | Path) -> Iterable[ExportMessageNode]:
     for x in read_oasst_jsonl(input_file_path):
         assert isinstance(x, ExportMessageNode)
         yield x
 
 
 def read_message_list(
-    input_file_path: str | Path, filter: Optional[Callable[[ExportMessageNode], bool]] = None
+    input_file_path: str | Path,
+    filter: Optional[Callable[[ExportMessageNode], bool]] = None,
 ) -> list[ExportMessageNode]:
     return [t for t in read_messages(input_file_path) if not filter or filter(t)]
+
+
+def read_dataset_message_trees(
+    hf_dataset_name: str = "OpenAssistant/oasst1",
+    split: str = "train+validation",
+) -> Iterable[ExportMessageTree]:
+    dataset = load_dataset(hf_dataset_name, split=split)
+
+    tree_dict: dict = None
+    parents: list = None
+    for row in dataset:
+        convert_hf_message(row)
+        if row["parent_id"] is None:
+            if tree_dict:
+                tree = read_oasst_obj(tree_dict)
+                assert isinstance(tree, ExportMessageTree)
+                yield tree
+
+            tree_dict = {
+                "message_tree_id": row["message_id"],
+                "tree_state": row["tree_state"],
+                "prompt": row,
+            }
+            parents = []
+        else:
+            while parents[-1]["message_id"] != row["parent_id"]:
+                parents.pop()
+            parent = parents[-1]
+            if "replies" not in parent:
+                parent["replies"] = []
+            parent["replies"].append(row)
+
+        row.pop("message_tree_id", None)
+        row.pop("tree_state", None)
+        parents.append(row)
+
+    if tree_dict:
+        tree = read_oasst_obj(tree_dict)
+        assert isinstance(tree, ExportMessageTree)
+        yield tree
+
+
+def read_dataset_messages(
+    hf_dataset_name: str = "OpenAssistant/oasst1",
+    split: str = "train+validation",
+) -> Iterable[ExportMessageNode]:
+    dataset = load_dataset(hf_dataset_name, split=split)
+
+    for row in dataset:
+        convert_hf_message(row)
+        message = read_oasst_obj(row)
+        assert isinstance(message, ExportMessageNode)
+        yield message
```
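
The new readers are usable on their own as well. A short, hedged sketch of both iterators and of the in-place columnar conversion done by `convert_hf_message` (the sample row is fabricated for illustration):

```python
from oasst_data import read_dataset_message_trees, read_dataset_messages
from oasst_data.reader import convert_hf_message

# Iterate fully reassembled message trees from the flat HF rows ...
for tree in read_dataset_message_trees("OpenAssistant/oasst1", split="validation"):
    print(tree.message_tree_id, tree.tree_state)
    break

# ... or individual messages without tree reconstruction.
for message in read_dataset_messages("OpenAssistant/oasst1", split="validation"):
    print(message.message_id, message.role)
    break

# convert_hf_message() rewrites the hub's columnar emoji/label encoding into
# the dict shape the export schemas expect (fabricated row):
row = {"emojis": {"name": ["+1", "_skip_reply"], "count": [3, 1]}, "labels": None}
convert_hf_message(row)
assert row["emojis"] == {"+1": 3, "_skip_reply": 1}
```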

oasst-data/pyproject.toml (2 additions & 1 deletion)

```diff
@@ -7,7 +7,8 @@ authors = [
 ]
 dependencies = [
     "pydantic>=1.10.4",
-    "loguru==0.6.0"
+    "loguru==0.6.0",
+    "datasets>=2.12.0"
 ]
 
 [project.optional-dependencies]
```
