forked from mosaicml/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtext_data.py
More file actions
241 lines (216 loc) · 10.2 KB
/
text_data.py
File metadata and controls
241 lines (216 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
"""Build a StreamingTextDataset dataset and dataloader for training."""
import os
import sys
from itertools import islice
from typing import Any, Dict, Iterator, Optional
import transformers
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import StreamingDataset
from torch.utils.data import DataLoader
class StreamingTextDataset(StreamingDataset):
    """Generic text dataset using MosaicML's StreamingDataset.

    Each raw sample is expected to carry its text under the ``'text'`` key; it is
    tokenized on access with a HuggingFace tokenizer.

    Args:
        local (str): Local dataset directory where shards are cached by split.
        tokenizer_name (str): The name of the HuggingFace tokenizer to use to
            tokenize samples.
        max_seq_len (int): The max sequence length of each sample. Must be positive.
        group_method (str): How to group text samples into token samples.
            Supports 'truncate' or 'concat'.
        remote (str, optional): Download shards from this remote path or directory. If None, this
            rank and worker's partition of the dataset must all exist locally. Defaults to ``None``.
        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``.
        predownload (int, optional): Target number of samples ahead to download the shards of while
            iterating. Defaults to ``100_000``.
        keep_zip (bool, optional): Whether to keep or delete the compressed file when
            decompressing downloaded shards. If set to None, keep if remote is local. Defaults to
            ``None``.
        download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
        download_timeout (float): Number of seconds to wait for a shard to download before raising
            an exception. Defaults to ``60``.
        validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
            shards. Defaults to ``None``.
        shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
            Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
            partitioned over the workers. Defaults to ``None``.

    Raises:
        ValueError: If ``group_method`` is not ``'truncate'`` or ``'concat'``, or if
            ``max_seq_len`` is not positive.
    """

    def __init__(self,
                 local: str,
                 tokenizer_name: str,
                 max_seq_len: int,
                 group_method: str,
                 remote: Optional[str] = None,
                 split: Optional[str] = None,
                 shuffle: bool = False,
                 predownload: Optional[int] = 100_000,
                 keep_zip: Optional[bool] = None,
                 download_retry: int = 2,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 shuffle_seed: int = 9176,
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None):
        # Validation
        if group_method not in ['truncate', 'concat']:
            raise ValueError(
                f"group_method='{group_method}' must be one of ['truncate', 'concat']."
            )
        # A non-positive length would make the 'concat' loop in __iter__ yield
        # empty samples forever, so reject it up front.
        if max_seq_len <= 0:
            raise ValueError(
                f'max_seq_len={max_seq_len} must be a positive integer.')

        # Build Dataset
        super().__init__(local=local,
                         remote=remote,
                         split=split,
                         shuffle=shuffle,
                         predownload=predownload,
                         keep_zip=keep_zip,
                         download_retry=download_retry,
                         download_timeout=download_timeout,
                         validate_hash=validate_hash,
                         shuffle_seed=shuffle_seed,
                         num_canonical_nodes=num_canonical_nodes,
                         batch_size=batch_size)
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len
        self.group_method = group_method

        # Build tokenizer
        os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.tokenizer_name)
        if self.tokenizer.pad_token is None:
            # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # suppress warnings when using group_method='concat' and no truncation
        self.tokenizer.model_max_length = int(1e30)

    def _tokenize(self, text_sample):
        """Tokenize ``text_sample['text']`` according to ``self.group_method``.

        'truncate' truncates and pads every sample to exactly ``max_seq_len``
        tokens; 'concat' tokenizes without truncation or padding so that
        ``__iter__`` can pack several samples into one sequence.
        """
        if self.group_method == 'truncate':
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        elif self.group_method == 'concat':
            truncation = False
            padding = False
            max_length = None
        else:
            raise ValueError(f"Got unknown group_method='{self.group_method}'.")
        return self.tokenizer(text_sample['text'],
                              truncation=truncation,
                              padding=padding,
                              max_length=max_length)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Fetch raw sample ``idx`` from the stream and return its tokenization."""
        text_sample = super().__getitem__(idx)
        token_sample = self._tokenize(text_sample)
        return token_sample

    def __iter__(self) -> Iterator[Any]:
        """Iterate over token samples.

        'truncate': yield the (already fixed-length) token samples produced by
        the parent class's iterator, one pass only.

        'concat': repeatedly iterate the parent dataset, appending every token
        field into a running buffer and yielding chunks of exactly
        ``max_seq_len`` entries per field (chunk length is driven by
        ``'input_ids'``). The stream is unbounded; leftover tokens stay in the
        buffer for the next chunk.
        """
        if self.group_method == 'truncate':
            yield from super().__iter__()
        elif self.group_method == 'concat':
            buffer: Dict[str, list] = {}
            while True:
                samples_this_pass = 0
                for sample in super().__iter__():
                    samples_this_pass += 1
                    for k, v in sample.items():
                        buffer.setdefault(k, []).extend(v)
                    while len(buffer['input_ids']) >= self.max_seq_len:
                        concat_sample = {}
                        for k, v in buffer.items():
                            concat_sample[k] = v[:self.max_seq_len]
                            buffer[k] = v[self.max_seq_len:]
                        yield concat_sample
                if samples_this_pass == 0:
                    # This worker's partition produced nothing, so the buffer
                    # can never fill. End the stream instead of spinning
                    # forever (which would also stall a multi-worker loader).
                    return
        else:
            raise ValueError(f"Got unknown group_method='{self.group_method}'.")

    def __len__(self) -> Optional[int]:
        """Return the number of samples for 'truncate'; ``None`` for 'concat'.

        'concat' repeats forever and has no defined length; note that calling
        the builtin ``len()`` on it therefore raises ``TypeError``.
        """
        if self.group_method == 'truncate':
            return super().__len__()
        elif self.group_method == 'concat':
            return None
        else:
            raise ValueError(f"Got unknown group_method='{self.group_method}'.")
def build_text_dataloader(cfg: DictConfig, device_batch_size: int):
    """Build a ``DataLoader`` over a :class:`StreamingTextDataset`.

    Args:
        cfg (DictConfig): Dataloader config. Must have ``name == 'text'`` and the
            keys ``dataset``, ``drop_last`` and ``num_workers``. ``cfg.dataset``
            must provide ``local``, ``tokenizer_name``, ``max_seq_len`` and
            ``group_method``; every other dataset/loader key is optional and
            falls back to the defaults below.
        device_batch_size (int): Per-device batch size, used both as the
            DataLoader batch size and to partition the streaming dataset.

    Returns:
        DataLoader: Batches collated by ``transformers.DataCollatorForLanguageModeling``,
        with masking enabled only when ``cfg.mlm_probability`` is set.
    """
    assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
    dataset = StreamingTextDataset(
        local=cfg.dataset.local,
        tokenizer_name=cfg.dataset.tokenizer_name,
        max_seq_len=cfg.dataset.max_seq_len,
        group_method=cfg.dataset.group_method,
        remote=cfg.dataset.get('remote', None),
        split=cfg.dataset.get('split', None),
        shuffle=cfg.dataset.get('shuffle', False),
        predownload=cfg.dataset.get('predownload', 100_000),
        keep_zip=cfg.dataset.get('keep_zip', False),
        download_retry=cfg.dataset.get('download_retry', 2),
        download_timeout=cfg.dataset.get('download_timeout', 60),
        validate_hash=cfg.dataset.get('validate_hash', None),
        # Match StreamingTextDataset's documented int default instead of None.
        shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
        num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
        batch_size=device_batch_size)

    mlm_probability = cfg.get('mlm_probability', None)
    collator_kwargs: Dict[str, Any] = {
        'tokenizer': dataset.tokenizer,
        'mlm': mlm_probability is not None,
    }
    # Only forward a real probability; otherwise keep the collator's own
    # (float) default rather than overriding it with None.
    if mlm_probability is not None:
        collator_kwargs['mlm_probability'] = mlm_probability
    collate_fn = transformers.DataCollatorForLanguageModeling(**collator_kwargs)

    loader_kwargs: Dict[str, Any] = {}
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0
    # (PyTorch >= 2.0), so only pass it when worker processes are used.
    if cfg.num_workers > 0:
        loader_kwargs['prefetch_factor'] = cfg.get('prefetch_factor', 2)

    return DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=device_batch_size,
        drop_last=cfg.drop_last,
        num_workers=cfg.num_workers,
        pin_memory=cfg.get('pin_memory', True),
        persistent_workers=cfg.get('persistent_workers', False),
        timeout=cfg.get('timeout', 0),
        **loader_kwargs,
    )
# Helpful to test if your dataloader is working locally.
# Run `python text_data.py <local> [remote]` and verify that batches are printed out.
if __name__ == '__main__':
    # Fail with a usage message instead of an IndexError when no local path is given.
    if len(sys.argv) < 2:
        sys.exit(f'Usage: python {sys.argv[0]} <local> [remote]')
    if len(sys.argv) > 2:
        local, remote = sys.argv[1:3]
        print(f'Reading val split from {local} <- streamed from <- {remote}')
    else:
        local = sys.argv[1]
        remote = None
        print(f'Reading val split from {local}')

    cfg = {
        'name': 'text',
        'dataset': {
            'local': local,
            'remote': remote,
            'split': 'val',
            'shuffle': False,
            'tokenizer_name': 'gpt2',
            'max_seq_len': 32,
            'group_method': 'truncate',
            'keep_zip': True,  # in case we need compressed files after testing
        },
        'drop_last': False,
        'num_workers': 4,
    }
    cfg = om.create(cfg)
    device_batch_size = 2

    loader = build_text_dataloader(cfg, device_batch_size)
    tokenizer = loader.dataset.tokenizer  # type: ignore
    # Print the shapes/dtypes and decoded text of the first few batches.
    for batch_ix, batch in enumerate(islice(loader, 5)):
        print('\n')
        print('#' * 20, f'Batch {batch_ix}', '#' * 20)
        for k, v in batch.items():
            print(k, v.shape, v.dtype)
        for sample_ix, token_sample in enumerate(batch['input_ids']):
            print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
            print(tokenizer.decode(token_sample))