
Commit 1c10aba

Integrate HF Datasets and add DatasetTuple (#1612)
* fix bart perf
* update fastergeneration doc
* add img
* add img
* change img
* update img
* fix img
* update docs
* fix readme
* update readme
* fix perf
* fix perf
* fix modelling
* fix perf and sample code
* fix perf
* fix perf
* fix seq_len for gpt_sample
* add forced eos token id for faster
* upgrade perf and add forced eos token id
* change stack to gather
* add auto perf
* minor fix
* remove encoder change
* Update bart_perf.py
* Update bart_perf.py
* 1. Integrate HF Datasets 2. return all splits by default 3. load_dataset returns DatasetTuple now
* add HF Dataset example
* add kwargs for HF load_dataset
* change datasets to alternative
* remove experimental
1 parent 32d01fa commit 1c10aba
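
For context on the headline change (load_dataset now returning a DatasetTuple and all splits by default), here is a minimal usage sketch based on the diff below. The dataset name, config, and split names ("glue", "sst-2", "train", "dev") are illustrative assumptions, not something this commit prescribes.

    # Illustrative sketch only; assumes PaddleNLP with this commit applied.
    from paddlenlp.datasets import load_dataset

    # With no `splits` argument, every split is returned, wrapped in a DatasetTuple.
    all_splits = load_dataset("glue", "sst-2")
    print(len(all_splits))         # number of splits
    print(all_splits["train"][0])  # access by split name ...
    print(all_splits[0][0])        # ... or by position, like a namedtuple

    # Requesting specific splits also yields a DatasetTuple, which unpacks in order.
    train_ds, dev_ds = load_dataset("glue", "sst-2", splits=("train", "dev"))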

2 files changed: +137 -73 lines changed


paddlenlp/datasets/dataset.py

Lines changed: 136 additions & 73 deletions
@@ -20,6 +20,7 @@
 import warnings
 import sys
 import inspect
+from collections import namedtuple
 from multiprocess import Pool, RLock
 import time

@@ -37,6 +38,27 @@
 DATASETS_MODULE_PATH = "paddlenlp.datasets."


+class DatasetTuple:
+    def __init__(self, splits):
+        self.tuple_cls = namedtuple('datasets', splits)
+        self.tuple = self.tuple_cls(* [None for _ in splits])
+
+    def __getitem__(self, key):
+        if isinstance(key, (int, slice)):
+            return self.tuple[key]
+        if isinstance(key, str):
+            return getattr(self.tuple, key)
+
+    def __repr__(self):
+        return self.tuple.__repr__()
+
+    def __setitem__(self, key, value):
+        self.tuple = self.tuple._replace(**{key: value})
+
+    def __len__(self):
+        return len(self.tuple)
+
+
 def import_main_class(module_path):
     """
     Import a module at module_path and return its DatasetBuilder class.
@@ -58,6 +80,40 @@ def import_main_class(module_path):
     return module_main_cls


+def load_from_hf(path, name=None, splits=None, **kwargs):
+    from datasets import load_dataset as load_hf_dataset
+    from datasets import DatasetDict
+    from datasets.features import ClassLabel
+    try:
+        hf_datasets = load_hf_dataset(path, name=name, split=splits, **kwargs)
+    except FileNotFoundError:
+        raise FileNotFoundError("Couldn't find the dataset script for '" + path
+                                + "' on PaddleNLP or HuggingFace")
+    else:
+        label_list = []
+        if isinstance(hf_datasets, DatasetDict):
+            datasets = DatasetTuple(hf_datasets.keys())
+            for split, ds in hf_datasets.items():
+                for feature in ds.features.values():
+                    if isinstance(feature, ClassLabel):
+                        label_list = feature.names
+                datasets[split] = MapDataset(ds, label_list=label_list)
+        elif isinstance(hf_datasets, list):
+            datasets = DatasetTuple(splits)
+            for i, split in enumerate(splits):
+                for feature in hf_datasets[i].features.values():
+                    if isinstance(feature, ClassLabel):
+                        label_list = feature.names
+                datasets[split] = MapDataset(
+                    hf_datasets[i], label_list=label_list)
+        else:
+            for feature in hf_datasets.features.values():
+                if isinstance(feature, ClassLabel):
+                    label_list = feature.names
+            datasets = MapDataset(hf_datasets, label_list=label_list)
+    return datasets
+
+
 def load_dataset(path_or_read_func,
                  name=None,
                  data_files=None,
@@ -109,37 +165,43 @@ def load_dataset(path_or_read_func,
         reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
         return reader_instance.read(**custom_kwargs)
     else:
-        reader_cls = import_main_class(path_or_read_func)
-        reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)
+        try:
+            reader_cls = import_main_class(path_or_read_func)
+        except ModuleNotFoundError:
+            datasets = load_from_hf(
+                path_or_read_func, name=name, splits=splits, **kwargs)
+        else:
+            reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

-        # Check if selected name and split is valid in this DatasetBuilder
-        if hasattr(reader_instance, 'BUILDER_CONFIGS'):
-            if name in reader_cls.BUILDER_CONFIGS.keys():
-                split_names = reader_cls.BUILDER_CONFIGS[name]['splits'].keys()
+            # Check if selected name and split is valid in this DatasetBuilder
+            if hasattr(reader_instance, 'BUILDER_CONFIGS'):
+                if name in reader_cls.BUILDER_CONFIGS.keys():
+                    split_names = reader_cls.BUILDER_CONFIGS[name][
+                        'splits'].keys()
+                else:
+                    raise ValueError(
+                        'Invalid name "{}". Should be one of {}.'.format(
+                            name, list(reader_cls.BUILDER_CONFIGS.keys())))
+            elif hasattr(reader_instance, 'SPLITS'):
+                split_names = reader_instance.SPLITS.keys()
             else:
-                raise ValueError(
-                    'Invalid name "{}". Should be one of {}.'.format(
-                        name, list(reader_cls.BUILDER_CONFIGS.keys())))
-        elif hasattr(reader_instance, 'SPLITS'):
-            split_names = reader_instance.SPLITS.keys()
-        else:
-            raise AttributeError(
-                "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
-            )
+                raise AttributeError(
+                    "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
+                )

-        selected_splits = []
-        if isinstance(splits, list) or isinstance(splits, tuple):
-            selected_splits.extend(splits)
-        else:
-            selected_splits += [splits]
+            selected_splits = []
+            if isinstance(splits, list) or isinstance(splits, tuple):
+                selected_splits.extend(splits)
+            else:
+                selected_splits += [splits]

-        for split_name in selected_splits:
-            if split_name not in split_names and split_name != None:
-                raise ValueError('Invalid split "{}". Should be one of {}.'.
-                                 format(split_name, list(split_names)))
+            for split_name in selected_splits:
+                if split_name not in split_names and split_name != None:
+                    raise ValueError('Invalid split "{}". Should be one of {}.'.
+                                     format(split_name, list(split_names)))

-        datasets = reader_instance.read_datasets(
-            data_files=data_files, splits=splits)
+            datasets = reader_instance.read_datasets(
+                data_files=data_files, splits=splits)
         return datasets

@@ -163,9 +225,9 @@ def __init__(self, data, **kwargs):
         self.data = data
         self._transform_pipline = []
         self.new_data = self.data
-
-        self.label_list = kwargs.pop('label_list', None)
-        self.vocab_info = kwargs.pop('vocab_info', None)
+        self.info = kwargs
+        self.label_list = self.info.pop('label_list', None)
+        self.vocab_info = self.info.pop('vocab_info', None)

     def _transform(self, data):
         for fn in self._transform_pipline:
@@ -198,23 +260,22 @@ def filter(self, fn, num_workers=0):
                 set to 0, it doesn't use multiprocessing. Defaults to `0`.
         """
         assert num_workers >= 0, "num_workers should be a non-negative value"
-        if num_workers > 0:
-            pool = Pool(
-                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
-
-            def filter_shard(num_workers, index, fn):
-                self.shard(num_shards=num_workers, index=index, contiguous=True)
-                self._filter(fn=fn)
-                return self
-
+        if num_workers > 1:
+            shards = [
+                self._shard(
+                    num_shards=num_workers, index=index, contiguous=True)
+                for index in range(num_workers)
+            ]
             kwds_per_shard = [
                 dict(
-                    num_workers=num_workers, index=rank, fn=fn)
-                for rank in range(num_workers)
+                    self=shards[rank], fn=fn) for rank in range(num_workers)
             ]
+            pool = Pool(num_workers, initargs=(RLock(), ))
+
             results = [
                 pool.apply_async(
-                    filter_shard, kwds=kwds) for kwds in kwds_per_shard
+                    self.__class__._filter, kwds=kwds)
+                for kwds in kwds_per_shard
             ]
             transformed_shards = [r.get() for r in results]

@@ -235,6 +296,11 @@ def _filter(self, fn):
         return self

     def shard(self, num_shards=None, index=None, contiguous=False):
+        self.new_data = self._shard(
+            num_shards=num_shards, index=index, contiguous=contiguous).data
+        return self
+
+    def _shard(self, num_shards=None, index=None, contiguous=False):
         """
         Split the dataset into `num_shards` pieces. Note that the size of each
         shard might be different because the original dataset may not be evenly
@@ -262,15 +328,14 @@ def shard(self, num_shards=None, index=None, contiguous=False):
             mod = len(self) % num_shards
             start = div * index + min(index, mod)
             end = start + div + (1 if index < mod else 0)
-            self.new_data = self.new_data[start:end]
+            new_data = [self.new_data[idx] for idx in range(start, end)]
         else:
-            num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
-            self.new_data = [
+            new_data = [
                 self.new_data[idx] for idx in range(len(self.new_data))
                 if idx % num_shards == index
             ]

-        return self
+        return MapDataset(new_data)

     def map(self, fn, lazy=True, batched=False, num_workers=0):
         """
@@ -292,25 +357,22 @@ def map(self, fn, lazy=True, batched=False, num_workers=0):
         """

         assert num_workers >= 0, "num_workers should be a non-negative value"
-        if num_workers > 0:
-
-            def map_shard(num_workers, index, fn, batched):
-                self.shard(num_shards=num_workers, index=index, contiguous=True)
-                self._map(fn=fn, lazy=False, batched=batched)
-                return self
-
+        if num_workers > 1:
+            shards = [
+                self._shard(
+                    num_shards=num_workers, index=index, contiguous=True)
+                for index in range(num_workers)
+            ]
             kwds_per_shard = [
                 dict(
-                    num_workers=num_workers, index=rank, fn=fn, batched=batched)
+                    self=shards[rank], fn=fn, lazy=False, batched=batched)
                 for rank in range(num_workers)
             ]
-            pool = Pool(
-                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
+            pool = Pool(num_workers, initargs=(RLock(), ))
             results = [
                 pool.apply_async(
-                    map_shard, kwds=kwds) for kwds in kwds_per_shard
+                    self.__class__._map, kwds=kwds) for kwds in kwds_per_shard
             ]
-
             transformed_shards = [r.get() for r in results]
             pool.close()
             pool.join()
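
The filter() and map() hunks above replace the closure-based worker functions with explicit contiguous shards that are dispatched to a multiprocess Pool through the unbound _filter/_map methods. Below is a standalone sketch of that shard-then-pool pattern, independent of PaddleNLP; the helper names and sample data are made up for illustration.

    # Standalone illustration of the shard-and-pool pattern; not PaddleNLP code.
    from multiprocess import Pool, RLock


    def _map_shard(shard, fn):
        # Worker: apply fn to every sample of one contiguous shard.
        return [fn(sample) for sample in shard]


    def parallel_map(data, fn, num_workers=4):
        # Contiguous sharding, mirroring _shard(contiguous=True) above.
        div, mod = divmod(len(data), num_workers)
        shards = []
        for index in range(num_workers):
            start = div * index + min(index, mod)
            end = start + div + (1 if index < mod else 0)
            shards.append(data[start:end])

        pool = Pool(num_workers, initargs=(RLock(), ))
        results = [pool.apply_async(_map_shard, args=(shard, fn)) for shard in shards]
        transformed = [r.get() for r in results]
        pool.close()
        pool.join()
        # Stitch the transformed shards back together in order.
        return [sample for shard in transformed for sample in shard]


    if __name__ == '__main__':
        print(parallel_map(list(range(10)), lambda x: x * x))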
@@ -471,9 +533,6 @@ def __init__(self, lazy=None, name=None, **config):
         self.config = config

     def read_datasets(self, splits=None, data_files=None):
-        datasets = []
-        assert splits or data_files, "`data_files` and `splits` can not both be None."
-
         def remove_if_exit(filepath):
             if isinstance(filepath, (list, tuple)):
                 for file in filepath:
@@ -487,14 +546,21 @@ def remove_if_exit(filepath):
                 except OSError:
                     pass

-        if splits and data_files is None:
+        if data_files is None:
+            if splits is None:
+                splits = list(self.BUILDER_CONFIGS[self.name]['splits'].keys(
+                )) if hasattr(self,
+                              "BUILDER_CONFIGS") else list(self.SPLITS.keys())
+
             assert isinstance(splits, str) or (
                 isinstance(splits, list) and isinstance(splits[0], str)
             ) or (
                 isinstance(splits, tuple) and isinstance(splits[0], str)
             ), "`splits` should be a string or list of string or a tuple of string."
+
             if isinstance(splits, str):
                 splits = [splits]
+            datasets = DatasetTuple(splits)
             parallel_env = dist.ParallelEnv()
             unique_endpoints = _get_unique_endpoints(
                 parallel_env.trainer_endpoints[:])
@@ -526,34 +592,31 @@ def remove_if_exit(filepath):
                 else:
                     while not os.path.exists(lock_file):
                         time.sleep(1)
-                datasets.append(self.read(filename=filename, split=split))
-
-        if data_files:
+                datasets[split] = self.read(filename=filename, split=split)
+        else:
             assert isinstance(data_files, str) or isinstance(
                 data_files, tuple) or isinstance(
                     data_files, list
                 ), "`data_files` should be a string or tuple or list of strings."
-
             if isinstance(data_files, str):
                 data_files = [data_files]
             default_split = 'train'
             if splits:
                 if isinstance(splits, str):
                     splits = [splits]
+                datasets = DatasetTuple(splits)
                 assert len(splits) == len(
                     data_files
                 ), "Number of `splits` and number of `data_files` should be the same if you want to specify the split of loacl data file."
-                datasets += [
-                    self.read(
+                for i in range(len(data_files)):
+                    datasets[splits[i]] = self.read(
                         filename=data_files[i], split=splits[i])
-                    for i in range(len(data_files))
-                ]
             else:
-                datasets += [
-                    self.read(
+                datasets = DatasetTuple(
+                    ["split" + str(i) for i in range(len(data_files))])
+                for i in range(len(data_files)):
+                    datasets["split" + str(i)] = self.read(
                         filename=data_files[i], split=default_split)
-                    for i in range(len(data_files))
-                ]

         return datasets if len(datasets) > 1 else datasets[0]

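For reference, here is a small standalone sketch of the Hugging Face fallback that load_from_hf implements above: load a dataset with datasets.load_dataset and collect label names from any ClassLabel feature. The "glue"/"sst2" arguments are illustrative assumptions, and this sketches the idea rather than the PaddleNLP wrapper itself.

    # Sketch of the Hugging Face path taken when no PaddleNLP builder is found.
    # Dataset and config names are illustrative; requires the `datasets` package.
    from datasets import load_dataset as load_hf_dataset
    from datasets import DatasetDict
    from datasets.features import ClassLabel

    hf_datasets = load_hf_dataset("glue", name="sst2")  # no split -> DatasetDict

    if isinstance(hf_datasets, DatasetDict):
        for split, ds in hf_datasets.items():
            label_list = []
            for feature in ds.features.values():
                if isinstance(feature, ClassLabel):
                    label_list = feature.names  # e.g. ['negative', 'positive']
            print(split, len(ds), label_list)
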
paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -1252,6 +1252,7 @@ def forward(self,
                 temperature=1.0,
                 num_return_sequences=1,
                 early_stopping=False,
+                forced_eos_token_id=None,
                 **model_kwargs):

         bos_token_id = bos_token_id if bos_token_id is not None else getattr(
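
The single change to faster_transformer.py threads a forced_eos_token_id argument through forward(). The general idea behind such a parameter in generation APIs is to override the scores of the final decoding step so that only the EOS token can be produced once max_length is reached. Below is a minimal, framework-agnostic illustration of that score processing; it is not the FasterTransformer implementation, and the numbers are made up.

    # Generic illustration of "force EOS at max length" score processing.
    # Not the FasterTransformer kernel; values below are made up.
    import numpy as np


    def force_eos_at_max_length(scores, cur_len, max_length, forced_eos_token_id):
        # On the last allowed step, mask every token id except the forced EOS id.
        if forced_eos_token_id is not None and cur_len == max_length - 1:
            forced = np.full_like(scores, -np.inf)
            forced[..., forced_eos_token_id] = 0.0
            return forced
        return scores


    scores = np.array([[0.1, 0.4, 0.2, 0.3]])  # batch of 1, vocab of 4
    print(force_eos_at_max_length(scores, cur_len=9, max_length=10,
                                  forced_eos_token_id=2))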
