 import sys
 import threading
 import time
+
+import datasets
 import torch
 from multiprocessing.connection import Connection

@@ -71,7 +73,7 @@ def tokenize(self, *text):

 class Encoder(object):
     def __init__(self, args):
-        self.json_keys = args.json_keys
+        self.content_keys = args.content_keys
         self.append_eod = args.append_eod
         # Use Encoder class as a container for global data
         self.tokenizer = build_tokenizer(args)
@@ -91,11 +93,14 @@ def __init__(self, args):
         else:
             self.splitter = IdentitySplitter()

-    def encode(self, json_line):
-        data = json.loads(json_line)
+    def encode(self, data):
         ids = {}
-        for key in self.json_keys:
+        # TODO: a character is not a byte for non-ascii scripts. this was like this before, maybe fix at some point
+        # (counting the actual bytes will slow down processing though)
+        bytes = 0
+        for key in self.content_keys:
             text = data[key]
+            bytes += len(text)
             doc_ids = []
             for sentence in self.splitter.tokenize(text):
                 sentence_ids = self.tokenizer.tokenize(sentence)
@@ -104,7 +109,7 @@ def encode(self, json_line):
             if len(doc_ids) > 0 and self.append_eod:
                 doc_ids[-1].append(self.tokenizer.eod)
             ids[key] = doc_ids
-        return ids, len(json_line)
+        return ids, bytes


 def process_samples(simple_queue, process_id, args, level, writer: Connection):
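The new `encode` returns a running total of `len(text)` as the byte count, which, as the TODO above notes, is really a character count. A tiny, self-contained illustration of the mismatch it describes (the example strings are made up):

```python
# len(str) counts Unicode characters; the UTF-8 byte count can be larger.
ascii_text = "hello world"
accented_text = "héllo wörld"

print(len(ascii_text), len(ascii_text.encode("utf-8")))        # equal for pure ASCII
print(len(accented_text), len(accented_text.encode("utf-8")))  # byte count is larger
```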
@@ -113,7 +118,7 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    for key in args.content_keys:
         output_filename = get_output_filename(args.output_prefix, key, level, process_id)
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -122,33 +127,38 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
                                                        impl=args.dataset_impl,
                                                        dtype=best_dtype)

-    json_lines = simple_queue.get()
-    while json_lines is not None:
-        process_json_lines(json_lines, encoder, builders, writer)
+    doc_lines = simple_queue.get()
+    while doc_lines is not None:
+        process_lines(doc_lines, encoder, builders, writer)

-        json_lines = simple_queue.get()
+        doc_lines = simple_queue.get()

     # In case finished, we still need to add None to signal to everyone else
     simple_queue.put(None)
     # Send None as end of sequence signal
     writer.send((None, process_id))
     writer.close()

-    for key in args.json_keys:
+    for key in args.content_keys:
         builders[key].finalize(output_idx_files[key])

     print(f"Worker {process_id} finished", flush=True)


-def process_json_lines(json_lines, encoder, builders, writer):
+def process_lines(lines, encoder, builders, writer):
     total_bytes_processed = 0
-    for json_line in json_lines:
-        if json_line.strip() == "":
-            continue
+    for line in lines:

-        doc, bytes_processed = encoder.encode(json_line)
+        if isinstance(line, str):
+            if line.strip() == "":
+                continue
+            data = json.loads(line)
+            doc, bytes_processed = encoder.encode(data)
+            total_bytes_processed += bytes_processed

-        total_bytes_processed += bytes_processed
+        elif isinstance(line, dict):
+            doc, bytes_processed = encoder.encode(line)
+            total_bytes_processed += bytes_processed

         for key, sentences in doc.items():
             if len(sentences) == 0:
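`process_lines` now accepts two item shapes from the queue: raw JSONL lines (`str`) from `fill_simple_queue_from_file` and already-decoded rows (`dict`) from the arrow reader. A minimal sketch of the two equivalent inputs, with made-up document text and the default `text` key:

```python
import json

jsonl_item = '{"text": "The quick brown fox."}'  # one line of a .jsonl file
arrow_item = {"text": "The quick brown fox."}    # one row of a datasets.Dataset

# The str branch decodes first; both branches then pass the same dict
# to encoder.encode and add the returned count to total_bytes_processed.
assert json.loads(jsonl_item) == arrow_item
```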
@@ -157,16 +167,16 @@ def process_json_lines(json_lines, encoder, builders, writer):
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()

-    writer.send((len(json_lines), total_bytes_processed))
+    writer.send((len(lines), total_bytes_processed))


 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
     group.add_argument('--input', type=str, required=True,
-                       help='Path to input JSON')
-    group.add_argument('--json-keys', nargs='+', default=['text'],
-                       help='space separate listed of keys to extract from json')
+                       help='Path to the input JSON file or arrow folder')
+    group.add_argument('--content-keys', nargs='+', default=['text'],
+                       help='space separated list of keys to extract from the data')
     group.add_argument('--split-sentences', action='store_true',
                        help='Split documents into sentences.')
     group.add_argument('--keep-newlines', action='store_true',
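Because the option is declared with a dash, argparse exposes it as `args.content_keys`, and each listed key gets its own builder and output file pair. A minimal sketch of just the renamed option (the real `get_args` defines many more flags that this diff does not show; the values passed below are made up):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True)
group.add_argument('--content-keys', nargs='+', default=['text'])

args = parser.parse_args(['--input', 'my_dataset_dir', '--content-keys', 'text', 'title'])
print(args.content_keys)  # ['text', 'title'] -> one indexed dataset per key
```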
@@ -219,7 +229,7 @@ def get_args():

     return args

-def fill_simple_queue(filename, simple_queue, chunk_size: int):
+def fill_simple_queue_from_file(filename, simple_queue, chunk_size: int):
     # TODO: Assess if instead we could feed pointers which process can then load.
     with open(filename, "r") as f:
         print("Start filling queue", flush=True)
@@ -231,6 +241,18 @@ def fill_simple_queue(filename, simple_queue, chunk_size: int):
                 return
             simple_queue.put(acc)

+def fill_simple_queue_from_arrow(dirname, simple_queue, chunk_size: int):
+    # TODO: Assess if instead we could feed pointers which process can then load.
+    dataset = datasets.load_from_disk(dirname)
+    print("Start filling queue", flush=True)
+    while True:
+        acc = tuple(itertools.islice(dataset, chunk_size))
+        if len(acc) == 0:
+            simple_queue.put(None)
+            print(f"Finished reading input file", flush=True)
+            return
+        simple_queue.put(acc)
+
 def log(readers, log_interval):
     print("Start Logging", flush=True)
     proc_start = time.time()
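The arrow reader above pulls `chunk_size` rows at a time out of the loaded dataset. A minimal, self-contained sketch of that chunking pattern (a tiny in-memory `Dataset` stands in for `datasets.load_from_disk(dirname)`); note that `itertools.islice` has to consume one shared iterator, since slicing the `Dataset` object itself creates a fresh iterator, and therefore re-reads the first rows, on every pass:

```python
import itertools

import datasets

# Stand-in for the dataset loaded from the arrow folder.
dataset = datasets.Dataset.from_dict({"text": ["doc one", "doc two", "doc three"]})

rows = iter(dataset)  # iterating a Dataset yields one dict per row, e.g. {"text": "doc one"}
while True:
    chunk = tuple(itertools.islice(rows, 2))  # take up to chunk_size rows
    if not chunk:
        break  # the real code puts None on the queue here to signal the end
    print(chunk)
```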
@@ -300,7 +322,12 @@ def main():
     process_ids = list(range(len(writers)))
     processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, process_id, args, level, writer)) for process_id, writer in zip(process_ids, writers)]
     log_thread = threading.Thread(target=log, args=(list(readers), args.log_interval))
-    fill_thread = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))
+    if os.path.isfile(args.input):
+        print("assuming `jsonl` input.")
+        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_file, args=(args.input, simple_queue, chunk_size))
+    elif os.path.isdir(args.input):
+        print("assuming arrow folder input for HF-datasets")
+        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_arrow, args=(args.input, simple_queue, chunk_size))

     fill_thread.start()
     log_thread.start()
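The dispatch above treats a plain file as JSONL input and a directory as a dataset saved in arrow format. A short sketch of producing such a directory with the `datasets` library, so the `os.path.isdir` branch can pick it up (the path and contents are illustrative only):

```python
import datasets

# Save a tiny dataset with the default "text" content key in the on-disk
# layout that datasets.load_from_disk() expects.
ds = datasets.Dataset.from_dict({"text": ["first document", "second document"]})
ds.save_to_disk("my_dataset_dir")

reloaded = datasets.load_from_disk("my_dataset_dir")
print(len(reloaded), reloaded[0])  # 2 {'text': 'first document'}
```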
@@ -332,7 +359,7 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    for key in args.content_keys:
         output_filename = f"{args.output_prefix}_{key}_{level}"
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -341,15 +368,15 @@ def main():
                                                        impl=args.dataset_impl,
                                                        dtype=best_dtype)

-    for key in args.json_keys:
+    for key in args.content_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             builders[key].merge_file_(output_filename)
         builders[key].finalize(output_idx_files[key])

     # Remove temporary files
     print("Removing shard files")
-    for key in args.json_keys:
+    for key in args.content_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             os.remove(data_file_path(output_filename))