
Commit d2992b7

Hugging Face dataset streaming support (#177)

* streaming dataset
* no stride / offset for streaming
* add recipe example for streaming
1 parent 02a9dbe commit d2992b7

File tree

5 files changed: +229 −1 lines changed


eole/config/data.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -210,7 +210,7 @@ def _validate_vocab_config(self, build_vocab_only=False):
     @staticmethod
     def _validate_file(file_path, info):
         """Check `file_path` is valid or raise `IOError`."""
-        if file_path == "dummy":
+        if file_path == "dummy" or file_path.startswith("hf://"):
             # hack to allow creating objects with required fields
             pass
         elif not os.path.isfile(file_path):
```
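For illustration, a minimal self-contained sketch of the relaxed check (the standalone function name and error message here are illustrative, not eole's exact API):

```python
import os

def validate_file(file_path: str) -> None:
    # Mirrors the updated check above: "dummy" and "hf://..." identifiers
    # skip the on-disk existence test, since a streaming dataset has no
    # local file to validate.
    if file_path == "dummy" or file_path.startswith("hf://"):
        pass  # resolved later by the streaming loader
    elif not os.path.isfile(file_path):
        raise IOError(f"Not a valid file: {file_path}")

validate_file("hf://eole-nlp/synth-mbr-decoded-sentlevel/en")  # accepted, no disk access
```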

eole/inputters/text_corpus.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -1,12 +1,14 @@
 """Module that contain shard utils for dynamic data."""
 
 import os
+import re
 from eole.utils.logging import logger
 from eole.constants import CorpusName, CorpusTask
 from eole.transforms import TransformPipe
 from eole.inputters.text_utils import transform_bucket
 from contextlib import contextmanager
 import itertools
+from datasets import load_dataset
 
 
 @contextmanager
@@ -102,11 +104,38 @@ def __init__(self, name, src, tgt, sco=None, align=None):
         self.sco = sco
         self.align = align
 
+    def _is_hf_dataset(self, path):
+        """
+        Check if a given path refers to a Hugging Face dataset.
+        Matches the 'hf://' prefix and assumes the dataset is in streaming mode.
+        Matches the last '/field' to get the language / score field.
+        """
+        pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
+        if isinstance(path, str):
+            return re.match(pattern, path)
+        else:
+            return None
+
+    def _load_hf_dataset(self, path):
+        """
+        Load a Hugging Face dataset from the given identifier.
+        Matches the 'hf://' prefix and assumes the dataset is in streaming mode.
+        Matches the last '/field' to get the language / score field.
+        """
+        pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
+        dataset_name = re.match(pattern, self.src).group(1)
+        return load_dataset(dataset_name, split="train", streaming=True)
+
     def load(self, offset=0, stride=1):
         """
         Load file and iterate by lines.
         `offset` and `stride` allow to iterate only on every
         `stride` example, starting from `offset`.
+        In the case of local files, all files are opened exactly the same way by each worker,
+        so we apply a stride / offset rule to make sure we do not process the same example twice.
+        In the case of HF streaming mode, we instead need more shards than workers.
+        Typically we recommend the shard count be a multiple of the worker count, for instance
+        16 shards for 4 workers on big datasets. Shards are iterated automatically, since HF
+        locks shards while they are in use.
         """
 
         def make_ex(sline, tline, scoline, align):
@@ -133,6 +162,16 @@ def make_ex(sline, tline, scoline, align):
             if scoline is None:
                 scoline = 1.0
             yield make_ex(sline, tline, scoline, align)
+
+        elif self._is_hf_dataset(self.src):
+            # If `src` is a Hugging Face dataset identifier
+            dataset = self._load_hf_dataset(self.src)
+            for i, example in enumerate(dataset):
+                sline = example.get(self.src.split("/")[-1])
+                tline = example.get(self.tgt.split("/")[-1])
+                scoline = example.get(self.sco.split("/")[-1], 1.0)
+                yield make_ex(sline, tline, scoline, None)
+
         else:
             with exfile_open(self.src, mode="rb") as fs, exfile_open(self.tgt, mode="rb") as ft, exfile_open(
                 self.sco, mode="rb"
```
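The `hf://org/dataset/field` convention above does double duty: the first two path components identify the dataset passed to `load_dataset`, while the trailing component names the column read from each streamed example. A quick sketch of how the pattern splits one of the recipe's paths:

```python
import re

# Same pattern as _is_hf_dataset / _load_hf_dataset above.
pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
m = re.match(pattern, "hf://eole-nlp/synth-mbr-decoded-sentlevel/en")
print(m.group(1))  # "eole-nlp/synth-mbr-decoded-sentlevel" -> dataset identifier
print(m.group(2))  # "en" -> column to read from each example

# Note that load() only uses group(1); the column name is recovered later
# with self.src.split("/")[-1], which yields the same "en" here.
```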
Lines changed: 42 additions & 0 deletions
# Example of using a Hugging Face streaming dataset

## Based on: https://arxiv.org/pdf/2408.06537

"Introducing the NewsPaLM MBR and QE Dataset: LLM-Generated High-Quality Parallel Data Outperforms Traditional Web-Crawled Data"

### Get the vocab and BPE model on HF

https://huggingface.co/eole-nlp/NewsPalmSynthetic-ENDE

Copy these files:
* ende.vocab2
* subwords.en_de.bpe

### Optionally, get the trained model to test it:

* config.json
* vocab.json
* model.00.safetensors

## Train with the yaml config file

```
eole train -c newspalm-synthetic-hfstreaming.yaml
```

## Translate the test set

```
eole predict -c inference.yaml --src newstest2023-src.en --output newstest2023-hyp.de
```

Then you can score with sacrebleu and/or comet.

Scoring with Unbabel/wmt22-comet-da gives: 81.90.

You can compare to Table 5, lines 2a) to 2d), of the paper: https://arxiv.org/pdf/2408.06537
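As a sketch, scoring can also be done from Python with the sacrebleu API (file names follow the predict command above; the comet toolkit works similarly):

```python
import sacrebleu

# Hypotheses produced by `eole predict`, plus the reference set used in
# the recipe's valid section.
with open("newstest2023-hyp.de") as f:
    hyps = [line.rstrip("\n") for line in f]
with open("newstest2023-ref.de") as f:
    refs = [line.rstrip("\n") for line in f]

print(sacrebleu.corpus_bleu(hyps, [refs]).score)
```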
Lines changed: 15 additions & 0 deletions
```yaml
# Model info
model_path: "Path to model.00.safetensors"
# Inference
max_length: 1024
max_length_ratio: 3
world_size: 1
gpu_ranks: [0]
batch_type: tokens
batch_size: 16384
compute_dtype: fp16
beam_size: 4
n_best: 1
report_time: true
self_attn_backend: "pytorch"
src: none
```
Lines changed: 132 additions & 0 deletions
```yaml
seed: 1234
share_vocab: true
src_vocab: "ende.vocab2"

src_words_min_frequency: 1
vocab_size_multiple: 8
report_every: 100
skip_empty_level: silent
valid_metrics: ["BLEU"]
scoring_debug: True

# transforms config
transforms: [onmt_tokenize, filtertoolong]
transforms_configs:
  onmt_tokenize:
    #### Subword
    src_subword_type: bpe
    src_subword_model: "subwords.en_de.bpe"
    src_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}
    tgt_subword_type: bpe
    tgt_subword_model: "subwords.en_de.bpe"
    tgt_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}

  filtertoolong:
    src_seq_length: 1024
    tgt_seq_length: 1024

# Corpus opts:
data:
  synth-mbr-decoded-sentlevel:
    # 997834 ex - 315MB
    path_src: "hf://eole-nlp/synth-mbr-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-mbr-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-mbr-decoded-sentlevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 12

  synth-greedy-decoded-sentlevel:
    # 832709 ex - 250MB
    path_src: "hf://eole-nlp/synth-greedy-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-greedy-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-greedy-decoded-sentlevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 10

  synth-qe-reranked-doclevel:
    # 417102 ex - 970MB
    path_src: "hf://eole-nlp/synth-qe-reranked-doclevel/en"
    path_tgt: "hf://eole-nlp/synth-qe-reranked-doclevel/de"
    path_sco: "hf://eole-nlp/synth-qe-reranked-doclevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 1

  synth-greedy-decoded-doclevel:
    # 857937 ex - 1.7GB
    path_src: "hf://eole-nlp/europarl-v10.de-en/en"
    path_tgt: "hf://eole-nlp/europarl-v10.de-en/de"
    path_sco: "hf://eole-nlp/europarl-v10.de-en/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 2

  valid:
    path_src: "newstest2023-src.en"
    path_tgt: "newstest2023-ref.de"
    transforms: [onmt_tokenize]

training:
  # General opts
  torch_compile: false

  model_path: "6-6-16-1024-4096-hfstreaming"
  keep_checkpoint: 50
  save_checkpoint_steps: 5000
  average_decay: 0.0005
  train_steps: 51000
  valid_steps: 100

  # Batching
  bucket_size: 10000
  num_workers: 4
  prefetch_factor: 400
  world_size: 1
  gpu_ranks: [0]
  batch_type: "tokens"
  batch_size: 12144
  valid_batch_size: 8192
  batch_size_multiple: 1
  accum_count: [6, 6, 6]
  accum_steps: [0, 15000, 30000]

  # Optimization
  compute_dtype: "fp16"
  apex_opt_level: ""
  optim: "adamw"
  reset_optim: "all"
  learning_rate: 1
  warmup_steps: 6000
  decay_method: "noam"
  adam_beta2: 0.998
  max_grad_norm: 1
  label_smoothing: 0.1
  param_init_method: "xavier_uniform"
  normalization: "tokens"

  dropout_steps: [0, 15000, 30000]
  dropout: [0.1, 0.1, 0.1]
  attention_dropout: [0.0, 0.0, 0.0]
  score_threshold: 0.65

  freeze_decoder: false
  freeze_encoder: false

model:
  architecture: "transformer"
  layers: 6
  hidden_size: 1024
  heads: 16
  transformer_ff: 4096
  add_qkvbias: false
  add_ffnbias: true
  mlp_activation_fn: gated-silu
  add_estimator: false
  share_decoder_embeddings: true
  share_embeddings: true
  layer_norm: standard
  norm_eps: 1e-6
  rope_config:
    rotary_interleave: false
  embeddings:
    word_vec_size: 1024
    position_encoding_type: "Rotary"
  freeze_word_vecs_dec: false
```
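Since the `hf://` corpora above never touch the local filesystem, one way to sanity-check an entry before a long training run is to stream a single example and confirm the expected columns are present. A minimal sketch, assuming the dataset exposes the `en` / `de` / `sco` fields referenced by `path_src` / `path_tgt` / `path_sco`:

```python
from datasets import load_dataset

# Same call the corpus loader issues for "hf://eole-nlp/synth-mbr-decoded-sentlevel/...".
ds = load_dataset("eole-nlp/synth-mbr-decoded-sentlevel", split="train", streaming=True)

example = next(iter(ds))
assert {"en", "de", "sco"} <= set(example), example.keys()
print(example["en"], example["de"], example["sco"])
```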
