@@ -20,6 +20,7 @@
 import warnings
 import sys
 import inspect
+from collections import namedtuple
 from multiprocess import Pool, RLock
 import time
 
@@ -37,6 +38,27 @@
 DATASETS_MODULE_PATH = "paddlenlp.datasets."
 
 
+class DatasetTuple:
+    def __init__(self, splits):
+        self.tuple_cls = namedtuple('datasets', splits)
+        self.tuple = self.tuple_cls(*[None for _ in splits])
+
+    def __getitem__(self, key):
+        if isinstance(key, (int, slice)):
+            return self.tuple[key]
+        if isinstance(key, str):
+            return getattr(self.tuple, key)
+
+    def __repr__(self):
+        return self.tuple.__repr__()
+
+    def __setitem__(self, key, value):
+        self.tuple = self.tuple._replace(**{key: value})
+
+    def __len__(self):
+        return len(self.tuple)
+
+
 def import_main_class(module_path):
     """
     Import a module at module_path and return its DatasetBuilder class.
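
For context, a minimal usage sketch of the DatasetTuple added above (split names and placeholder values are illustrative, not part of the change); it allows both name-based and positional access to loaded splits:

# Assumes the DatasetTuple class added above is in scope.
datasets = DatasetTuple(['train', 'dev'])      # every slot starts as None
datasets['train'] = 'train_ds_placeholder'     # __setitem__ goes through namedtuple._replace
datasets['dev'] = 'dev_ds_placeholder'

print(datasets['train'])   # access by split name
print(datasets[0])         # access by position
print(len(datasets))       # 2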
@@ -58,6 +80,40 @@ def import_main_class(module_path):
     return module_main_cls
 
 
+def load_from_hf(path, name=None, splits=None, **kwargs):
+    from datasets import load_dataset as load_hf_dataset
+    from datasets import DatasetDict
+    from datasets.features import ClassLabel
+    try:
+        hf_datasets = load_hf_dataset(path, name=name, split=splits, **kwargs)
+    except FileNotFoundError:
+        raise FileNotFoundError("Couldn't find the dataset script for '" + path
+                                + "' on PaddleNLP or HuggingFace")
+    else:
+        label_list = []
+        if isinstance(hf_datasets, DatasetDict):
+            datasets = DatasetTuple(hf_datasets.keys())
+            for split, ds in hf_datasets.items():
+                for feature in ds.features.values():
+                    if isinstance(feature, ClassLabel):
+                        label_list = feature.names
+                datasets[split] = MapDataset(ds, label_list=label_list)
+        elif isinstance(hf_datasets, list):
+            datasets = DatasetTuple(splits)
+            for i, split in enumerate(splits):
+                for feature in hf_datasets[i].features.values():
+                    if isinstance(feature, ClassLabel):
+                        label_list = feature.names
+                datasets[split] = MapDataset(
+                    hf_datasets[i], label_list=label_list)
+        else:
+            for feature in hf_datasets.features.values():
+                if isinstance(feature, ClassLabel):
+                    label_list = feature.names
+            datasets = MapDataset(hf_datasets, label_list=label_list)
+    return datasets
+
+
 def load_dataset(path_or_read_func,
                  name=None,
                  data_files=None,
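
A rough usage sketch of the new HuggingFace fallback (the dataset identifiers below are illustrative and assume the `datasets` package plus network access): a DatasetDict or a list of HF splits comes back as a DatasetTuple of MapDataset objects, with label_list filled from any ClassLabel feature.

# Hypothetical call; "glue"/"sst2" are example Hub identifiers, not taken from this diff.
hf_data = load_from_hf("glue", name="sst2", splits=["train", "validation"])
train_ds = hf_data["train"]        # a paddlenlp MapDataset wrapping the HF split
print(train_ds.label_list)         # e.g. ['negative', 'positive'] from the ClassLabel feature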
@@ -109,37 +165,43 @@ def load_dataset(path_or_read_func,
         reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
         return reader_instance.read(**custom_kwargs)
     else:
-        reader_cls = import_main_class(path_or_read_func)
-        reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)
+        try:
+            reader_cls = import_main_class(path_or_read_func)
+        except ModuleNotFoundError:
+            datasets = load_from_hf(
+                path_or_read_func, name=name, splits=splits, **kwargs)
+        else:
+            reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)
 
-        # Check if selected name and split is valid in this DatasetBuilder
-        if hasattr(reader_instance, 'BUILDER_CONFIGS'):
-            if name in reader_cls.BUILDER_CONFIGS.keys():
-                split_names = reader_cls.BUILDER_CONFIGS[name]['splits'].keys()
+            # Check if selected name and split is valid in this DatasetBuilder
+            if hasattr(reader_instance, 'BUILDER_CONFIGS'):
+                if name in reader_cls.BUILDER_CONFIGS.keys():
+                    split_names = reader_cls.BUILDER_CONFIGS[name][
+                        'splits'].keys()
+                else:
+                    raise ValueError(
+                        'Invalid name "{}". Should be one of {}.'.format(
+                            name, list(reader_cls.BUILDER_CONFIGS.keys())))
+            elif hasattr(reader_instance, 'SPLITS'):
+                split_names = reader_instance.SPLITS.keys()
             else:
-                raise ValueError(
-                    'Invalid name "{}". Should be one of {}.'.format(
-                        name, list(reader_cls.BUILDER_CONFIGS.keys())))
-        elif hasattr(reader_instance, 'SPLITS'):
-            split_names = reader_instance.SPLITS.keys()
-        else:
-            raise AttributeError(
-                "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
-            )
+                raise AttributeError(
+                    "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
+                )
 
-        selected_splits = []
-        if isinstance(splits, list) or isinstance(splits, tuple):
-            selected_splits.extend(splits)
-        else:
-            selected_splits += [splits]
+            selected_splits = []
+            if isinstance(splits, list) or isinstance(splits, tuple):
+                selected_splits.extend(splits)
+            else:
+                selected_splits += [splits]
 
-        for split_name in selected_splits:
-            if split_name not in split_names and split_name != None:
-                raise ValueError('Invalid split "{}". Should be one of {}.'.
-                                 format(split_name, list(split_names)))
+            for split_name in selected_splits:
+                if split_name not in split_names and split_name != None:
+                    raise ValueError('Invalid split "{}". Should be one of {}.'.
+                                     format(split_name, list(split_names)))
 
-        datasets = reader_instance.read_datasets(
-            data_files=data_files, splits=splits)
+            datasets = reader_instance.read_datasets(
+                data_files=data_files, splits=splits)
         return datasets
 
 
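A quick sketch of what the try/except around import_main_class buys callers (the dataset names are examples only): built-in datasets keep their DatasetBuilder path, while unknown names now fall back to the Hub instead of surfacing ModuleNotFoundError.

# Served by a PaddleNLP DatasetBuilder, as before.
train_ds, dev_ds = load_dataset("chnsenticorp", splits=("train", "dev"))

# No such PaddleNLP module -> import_main_class raises ModuleNotFoundError,
# so load_from_hf is tried; "imdb" is only an illustrative HF dataset id.
imdb_train = load_dataset("imdb", splits="train")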
@@ -163,9 +225,9 @@ def __init__(self, data, **kwargs):
         self.data = data
         self._transform_pipline = []
         self.new_data = self.data
-
-        self.label_list = kwargs.pop('label_list', None)
-        self.vocab_info = kwargs.pop('vocab_info', None)
+        self.info = kwargs
+        self.label_list = self.info.pop('label_list', None)
+        self.vocab_info = self.info.pop('vocab_info', None)
 
     def _transform(self, data):
         for fn in self._transform_pipline:
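
A small sketch of the effect of keeping the remaining kwargs on the instance (the field names below are made up):

ds = MapDataset(
    [{"text": "x", "label": 0}],
    label_list=["neg", "pos"],
    source="hypothetical")
print(ds.label_list)   # ['neg', 'pos'], popped out as before
print(ds.info)         # {'source': 'hypothetical'} -- leftover kwargs are now retained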
@@ -198,23 +260,22 @@ def filter(self, fn, num_workers=0):
                 set to 0, it doesn't use multiprocessing. Defaults to `0`.
         """
         assert num_workers >= 0, "num_workers should be a non-negative value"
-        if num_workers > 0:
-            pool = Pool(
-                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
-
-            def filter_shard(num_workers, index, fn):
-                self.shard(num_shards=num_workers, index=index, contiguous=True)
-                self._filter(fn=fn)
-                return self
-
+        if num_workers > 1:
+            shards = [
+                self._shard(
+                    num_shards=num_workers, index=index, contiguous=True)
+                for index in range(num_workers)
+            ]
             kwds_per_shard = [
                 dict(
-                    num_workers=num_workers, index=rank, fn=fn)
-                for rank in range(num_workers)
+                    self=shards[rank], fn=fn) for rank in range(num_workers)
             ]
+            pool = Pool(num_workers, initargs=(RLock(), ))
+
             results = [
                 pool.apply_async(
-                    filter_shard, kwds=kwds) for kwds in kwds_per_shard
+                    self.__class__._filter, kwds=kwds)
+                for kwds in kwds_per_shard
             ]
             transformed_shards = [r.get() for r in results]
 
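Two things worth noting for reviewers: the guard changed from num_workers > 0 to num_workers > 1, so a single worker now runs in the calling process, and the parallel path now ships pre-built shards to pool.apply_async(MapDataset._filter, ...) instead of a nested closure. A hypothetical call:

# Keeps only label==1 examples; with num_workers=4, each of the 4 contiguous
# shards is filtered in its own worker process and the results merged back.
ds.filter(lambda example: example["label"] == 1, num_workers=4)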
@@ -235,6 +296,11 @@ def _filter(self, fn):
         return self
 
     def shard(self, num_shards=None, index=None, contiguous=False):
+        self.new_data = self._shard(
+            num_shards=num_shards, index=index, contiguous=contiguous).data
+        return self
+
+    def _shard(self, num_shards=None, index=None, contiguous=False):
         """
         Split the dataset into `num_shards` pieces. Note that the size of each
         shard might be different because the original dataset may not be evenly
@@ -262,15 +328,14 @@ def shard(self, num_shards=None, index=None, contiguous=False):
             mod = len(self) % num_shards
             start = div * index + min(index, mod)
             end = start + div + (1 if index < mod else 0)
-            self.new_data = self.new_data[start:end]
+            new_data = [self.new_data[idx] for idx in range(start, end)]
         else:
-            num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
-            self.new_data = [
+            new_data = [
                 self.new_data[idx] for idx in range(len(self.new_data))
                 if idx % num_shards == index
             ]
 
-        return self
+        return MapDataset(new_data)
 
     def map(self, fn, lazy=True, batched=False, num_workers=0):
         """
@@ -292,25 +357,22 @@ def map(self, fn, lazy=True, batched=False, num_workers=0):
         """
 
         assert num_workers >= 0, "num_workers should be a non-negative value"
-        if num_workers > 0:
-
-            def map_shard(num_workers, index, fn, batched):
-                self.shard(num_shards=num_workers, index=index, contiguous=True)
-                self._map(fn=fn, lazy=False, batched=batched)
-                return self
-
+        if num_workers > 1:
+            shards = [
+                self._shard(
+                    num_shards=num_workers, index=index, contiguous=True)
+                for index in range(num_workers)
+            ]
             kwds_per_shard = [
                 dict(
-                    num_workers=num_workers, index=rank, fn=fn, batched=batched)
+                    self=shards[rank], fn=fn, lazy=False, batched=batched)
                 for rank in range(num_workers)
             ]
-            pool = Pool(
-                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
+            pool = Pool(num_workers, initargs=(RLock(), ))
             results = [
                 pool.apply_async(
-                    map_shard, kwds=kwds) for kwds in kwds_per_shard
+                    self.__class__._map, kwds=kwds) for kwds in kwds_per_shard
             ]
-
             transformed_shards = [r.get() for r in results]
             pool.close()
             pool.join()
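
Same pattern as filter above; a hypothetical parallel map call (the example assumes dict-style examples with a "text" field):

def add_length(example):
    example["length"] = len(example["text"])
    return example

# Each contiguous shard is mapped eagerly (lazy=False) in a worker process via
# pool.apply_async(MapDataset._map, ...), then the shards are stitched back together.
ds.map(add_length, num_workers=4)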
@@ -471,9 +533,6 @@ def __init__(self, lazy=None, name=None, **config):
         self.config = config
 
     def read_datasets(self, splits=None, data_files=None):
-        datasets = []
-        assert splits or data_files, "`data_files` and `splits` can not both be None."
-
         def remove_if_exit(filepath):
             if isinstance(filepath, (list, tuple)):
                 for file in filepath:
@@ -487,14 +546,21 @@ def remove_if_exit(filepath):
             except OSError:
                 pass
 
-        if splits and data_files is None:
+        if data_files is None:
+            if splits is None:
+                splits = list(self.BUILDER_CONFIGS[self.name]['splits'].keys(
+                )) if hasattr(self,
+                              "BUILDER_CONFIGS") else list(self.SPLITS.keys())
+
             assert isinstance(splits, str) or (
                 isinstance(splits, list) and isinstance(splits[0], str)
             ) or (
                 isinstance(splits, tuple) and isinstance(splits[0], str)
             ), "`splits` should be a string or list of string or a tuple of string."
+
             if isinstance(splits, str):
                 splits = [splits]
+            datasets = DatasetTuple(splits)
             parallel_env = dist.ParallelEnv()
             unique_endpoints = _get_unique_endpoints(
                 parallel_env.trainer_endpoints[:])
@@ -526,34 +592,31 @@ def remove_if_exit(filepath):
             else:
                 while not os.path.exists(lock_file):
                     time.sleep(1)
-            datasets.append(self.read(filename=filename, split=split))
-
-        if data_files:
+            datasets[split] = self.read(filename=filename, split=split)
+        else:
             assert isinstance(data_files, str) or isinstance(
                 data_files, tuple) or isinstance(
                     data_files, list
                 ), "`data_files` should be a string or tuple or list of strings."
-
             if isinstance(data_files, str):
                 data_files = [data_files]
             default_split = 'train'
             if splits:
                 if isinstance(splits, str):
                     splits = [splits]
+                datasets = DatasetTuple(splits)
                 assert len(splits) == len(
                     data_files
                 ), "Number of `splits` and number of `data_files` should be the same if you want to specify the split of local data file."
-                datasets += [
-                    self.read(
+                for i in range(len(data_files)):
+                    datasets[splits[i]] = self.read(
                         filename=data_files[i], split=splits[i])
-                    for i in range(len(data_files))
-                ]
             else:
-                datasets += [
-                    self.read(
+                datasets = DatasetTuple(
+                    ["split" + str(i) for i in range(len(data_files))])
+                for i in range(len(data_files)):
+                    datasets["split" + str(i)] = self.read(
                         filename=data_files[i], split=default_split)
-                    for i in range(len(data_files))
-                ]
 
         return datasets if len(datasets) > 1 else datasets[0]
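
Finally, a sketch of how the return value changes for callers (the builder class and its splits are hypothetical): read_datasets now fills a DatasetTuple keyed by split instead of a plain list, and when both splits and data_files are omitted it defaults to every split the builder declares.

# Assume MyBuilder is a DatasetBuilder subclass with SPLITS = {'train': ..., 'dev': ...}.
builder = MyBuilder(lazy=False)

both = builder.read_datasets()      # splits=None -> all declared splits
train_ds = both['train']            # lookup by split name
train_ds, dev_ds = both             # positional unpacking still works via __getitem__
only_train = builder.read_datasets(splits='train')   # single split -> the MapDataset itself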