@@ -35,8 +35,6 @@
 import sys
 import threading
 import time
-
-import datasets
 import torch
 from multiprocessing.connection import Connection
 
@@ -73,7 +71,7 @@ def tokenize(self, *text):
 
 class Encoder(object):
     def __init__(self, args):
-        self.content_keys = args.content_keys
+        self.json_keys = args.json_keys
         self.append_eod = args.append_eod
         # Use Encoder class as a container for global data
         self.tokenizer = build_tokenizer(args)
@@ -93,14 +91,11 @@ def __init__(self, args):
         else:
             self.splitter = IdentitySplitter()
 
-    def encode(self, data):
+    def encode(self, json_line):
+        data = json.loads(json_line)
         ids = {}
-        # TODO: a character is not a byte for non-ascii scripts. this was like this before, maybe fix at some point
-        # (counting the actual bytes will slow down processing though)
-        bytes = 0
-        for key in self.content_keys:
+        for key in self.json_keys:
             text = data[key]
-            bytes += len(text)
             doc_ids = []
             for sentence in self.splitter.tokenize(text):
                 sentence_ids = self.tokenizer.tokenize(sentence)
@@ -109,7 +104,7 @@ def encode(self, data):
             if len(doc_ids) > 0 and self.append_eod:
                 doc_ids[-1].append(self.tokenizer.eod)
             ids[key] = doc_ids
-        return ids, bytes
+        return ids, len(json_line)
 
 
 def process_samples(simple_queue, process_id, args, level, writer: Connection):
@@ -118,7 +113,7 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.content_keys:
+    for key in args.json_keys:
         output_filename = get_output_filename(args.output_prefix, key, level, process_id)
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -127,38 +122,33 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
                                                      impl=args.dataset_impl,
                                                      dtype=best_dtype)
 
-    doc_lines = simple_queue.get()
-    while doc_lines is not None:
-        process_lines(doc_lines, encoder, builders, writer)
+    json_lines = simple_queue.get()
+    while json_lines is not None:
+        process_json_lines(json_lines, encoder, builders, writer)
 
-        doc_lines = simple_queue.get()
+        json_lines = simple_queue.get()
 
     # In case finished, we still need to add None to signal to everyone else
     simple_queue.put(None)
     # Send None as end of sequence signal
     writer.send((None, process_id))
     writer.close()
 
-    for key in args.content_keys:
+    for key in args.json_keys:
         builders[key].finalize(output_idx_files[key])
 
     print(f"Worker {process_id} finished", flush=True)
 
 
-def process_lines(lines, encoder, builders, writer):
+def process_json_lines(json_lines, encoder, builders, writer):
     total_bytes_processed = 0
-    for line in lines:
+    for json_line in json_lines:
+        if json_line.strip() == "":
+            continue
 
-        if isinstance(line, str):
-            if line.strip() == "":
-                continue
-            data = json.loads(line)
-            doc, bytes_processed = encoder.encode(data)
-            total_bytes_processed += bytes_processed
+        doc, bytes_processed = encoder.encode(json_line)
 
-        elif isinstance(line, dict):
-            doc, bytes_processed = encoder.encode(line)
-            total_bytes_processed += bytes_processed
+        total_bytes_processed += bytes_processed
 
         for key, sentences in doc.items():
             if len(sentences) == 0:
@@ -167,16 +157,16 @@ def process_lines(lines, encoder, builders, writer):
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()
 
-    writer.send((len(lines), total_bytes_processed))
+    writer.send((len(json_lines), total_bytes_processed))
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
     group.add_argument('--input', type=str, required=True,
-                       help='Path to input JSON or arrow file')
-    group.add_argument('--content-keys', nargs='+', default=['text'],
-                       help='space separate listed of keys to extract from data')
+                       help='Path to input JSON')
+    group.add_argument('--json-keys', nargs='+', default=['text'],
+                       help='space separate listed of keys to extract from json')
     group.add_argument('--split-sentences', action='store_true',
                        help='Split documents into sentences.')
     group.add_argument('--keep-newlines', action='store_true',
@@ -229,7 +219,7 @@ def get_args():
 
     return args
 
-def fill_simple_queue_from_file(filename, simple_queue, chunk_size:int):
+def fill_simple_queue(filename, simple_queue, chunk_size:int):
     # TODO: Assess if instead we could feed pointers which process can then load.
     with open(filename, "r") as f:
         print("Start filling queue", flush=True)
@@ -241,18 +231,6 @@ def fill_simple_queue_from_file(filename, simple_queue, chunk_size:int):
                 return
             simple_queue.put(acc)
 
-def fill_simple_queue_from_arrow(dirname, simple_queue, chunk_size:int):
-    # TODO: Assess if instead we could feed pointers which process can then load.
-    dataset = datasets.load_from_disk(dirname)
-    print("Start filling queue", flush=True)
-    while True:
-        acc = tuple(itertools.islice(dataset, chunk_size))
-        if len(acc) == 0:
-            simple_queue.put(None)
-            print(f"Finished reading input file", flush=True)
-            return
-        simple_queue.put(acc)
-
 def log(readers, log_interval):
     print("Start Logging", flush=True)
     proc_start = time.time()
@@ -322,12 +300,7 @@ def main():
     process_ids = list(range(len(writers)))
     processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, process_id, args, level, writer)) for process_id, writer in zip(process_ids, writers)]
     log_thread = threading.Thread(target=log, args=(list(readers), args.log_interval))
-    if os.path.isfile(args.input):
-        print("assuming `jsonl` input.")
-        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_file, args=(args.input, simple_queue, chunk_size))
-    elif os.path.isdir(args.input):
-        print("assuming arrow folder input for HF-datasets")
-        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_arrow, args=(args.input, simple_queue, chunk_size))
+    fill_thread = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))
 
     fill_thread.start()
     log_thread.start()
@@ -359,7 +332,7 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.content_keys:
+    for key in args.json_keys:
         output_filename = f"{args.output_prefix}_{key}_{level}"
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -368,15 +341,15 @@ def main():
                                                      impl=args.dataset_impl,
                                                      dtype=best_dtype)
 
-    for key in args.content_keys:
+    for key in args.json_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             builders[key].merge_file_(output_filename)
         builders[key].finalize(output_idx_files[key])
 
     # Remove temporary files
     print("Removing shard files")
-    for key in args.content_keys:
+    for key in args.json_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             os.remove(data_file_path(output_filename))
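
Editor's note (not part of the patch): the sketch below only illustrates the interface this commit reverts to, in which encode() receives the raw JSON line, parses it itself, and reports len(json_line) as the byte count. The tokenizer and splitter are toy stand-ins for build_tokenizer and the sentence splitter, and ToyEncoder is a hypothetical name used purely for illustration.

import json

class ToyEncoder:
    """Illustrative stand-in for Encoder: whitespace 'tokenizer', identity 'splitter'."""
    def __init__(self, json_keys=("text",), append_eod=True, eod_id=0):
        self.json_keys = json_keys
        self.append_eod = append_eod
        self.eod_id = eod_id

    def encode(self, json_line):
        data = json.loads(json_line)  # parsing now happens inside encode()
        ids = {}
        for key in self.json_keys:
            doc_ids = []
            for sentence in [data[key]]:  # identity splitter: the whole doc is one "sentence"
                sentence_ids = [abs(hash(tok)) % 50000 for tok in sentence.split()]
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if len(doc_ids) > 0 and self.append_eod:
                doc_ids[-1].append(self.eod_id)
            ids[key] = doc_ids
        # bytes are approximated by the character length of the raw line,
        # matching the behaviour restored by this commit
        return ids, len(json_line)

ids, nbytes = ToyEncoder().encode('{"text": "hello world"}')
print(ids, nbytes)  # one doc with two token ids plus the EOD id, and the raw line length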