Commit d0fcf41

preprocessing from arrow file to load an HF dataset
1 parent b9883f4 commit d0fcf41

File tree: 1 file changed (+53, -26 lines)


tools/preprocess_data_many_cores.py (53 additions, 26 deletions)
@@ -35,6 +35,8 @@
 import sys
 import threading
 import time
+
+import datasets
 import torch
 from multiprocessing.connection import Connection
 
@@ -71,7 +73,7 @@ def tokenize(self, *text):
 
 class Encoder(object):
     def __init__(self, args):
-        self.json_keys = args.json_keys
+        self.content_keys = args.content_keys
         self.append_eod = args.append_eod
         # Use Encoder class as a container for global data
         self.tokenizer = build_tokenizer(args)
@@ -91,11 +93,14 @@ def __init__(self, args):
         else:
             self.splitter = IdentitySplitter()
 
-    def encode(self, json_line):
-        data = json.loads(json_line)
+    def encode(self, data):
         ids = {}
-        for key in self.json_keys:
+        # TODO: a character is not a byte for non-ascii scripts. this was like this before, maybe fix at some point
+        # (counting the actual bytes will slow down processing though)
+        bytes = 0
+        for key in self.content_keys:
             text = data[key]
+            bytes += len(text)
             doc_ids = []
             for sentence in self.splitter.tokenize(text):
                 sentence_ids = self.tokenizer.tokenize(sentence)
@@ -104,7 +109,7 @@ def encode(self, json_line):
             if len(doc_ids) > 0 and self.append_eod:
                 doc_ids[-1].append(self.tokenizer.eod)
             ids[key] = doc_ids
-        return ids, len(json_line)
+        return ids, bytes
 
 
 def process_samples(simple_queue, process_id, args, level, writer: Connection):
@@ -113,7 +118,7 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    for key in args.content_keys:
         output_filename = get_output_filename(args.output_prefix, key, level, process_id)
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -122,33 +127,38 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
                                                impl=args.dataset_impl,
                                                dtype=best_dtype)
 
-    json_lines = simple_queue.get()
-    while json_lines is not None:
-        process_json_lines(json_lines, encoder, builders, writer)
+    doc_lines = simple_queue.get()
+    while doc_lines is not None:
+        process_lines(doc_lines, encoder, builders, writer)
 
-        json_lines = simple_queue.get()
+        doc_lines = simple_queue.get()
 
     # In case finished, we still need to add None to signal to everyone else
     simple_queue.put(None)
     # Send None as end of sequence signal
     writer.send((None, process_id))
     writer.close()
 
-    for key in args.json_keys:
+    for key in args.content_keys:
         builders[key].finalize(output_idx_files[key])
 
     print(f"Worker {process_id} finished", flush=True)
 
 
-def process_json_lines(json_lines, encoder, builders, writer):
+def process_lines(lines, encoder, builders, writer):
     total_bytes_processed = 0
-    for json_line in json_lines:
-        if json_line.strip() == "":
-            continue
+    for line in lines:
 
-        doc, bytes_processed = encoder.encode(json_line)
+        if isinstance(line, str):
+            if line.strip() == "":
+                continue
+            data = json.loads(line)
+            doc, bytes_processed = encoder.encode(data)
+            total_bytes_processed += bytes_processed
 
-        total_bytes_processed += bytes_processed
+        elif isinstance(line, dict):
+            doc, bytes_processed = encoder.encode(line)
+            total_bytes_processed += bytes_processed
 
         for key, sentences in doc.items():
            if len(sentences) == 0:
@@ -157,16 +167,16 @@ def process_json_lines(json_lines, encoder, builders, writer):
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()
 
-    writer.send((len(json_lines), total_bytes_processed))
+    writer.send((len(lines), total_bytes_processed))
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
     group.add_argument('--input', type=str, required=True,
-                       help='Path to input JSON')
-    group.add_argument('--json-keys', nargs='+', default=['text'],
-                       help='space separate listed of keys to extract from json')
+                       help='Path to input JSON or arrow file')
+    group.add_argument('--content-keys', nargs='+', default=['text'],
+                       help='space separate listed of keys to extract from data')
     group.add_argument('--split-sentences', action='store_true',
                        help='Split documents into sentences.')
     group.add_argument('--keep-newlines', action='store_true',
@@ -219,7 +229,7 @@ def get_args():
 
     return args
 
-def fill_simple_queue(filename, simple_queue, chunk_size:int):
+def fill_simple_queue_from_file(filename, simple_queue, chunk_size:int):
     # TODO: Assess if instead we could feed pointers which process can then load.
     with open(filename, "r") as f:
         print("Start filling queue", flush=True)
@@ -231,6 +241,18 @@ def fill_simple_queue(filename, simple_queue, chunk_size:int):
                 return
             simple_queue.put(acc)
 
+def fill_simple_queue_from_arrow(dirname, simple_queue, chunk_size:int):
+    # TODO: Assess if instead we could feed pointers which process can then load.
+    dataset = datasets.load_from_disk(dirname)
+    print("Start filling queue", flush=True)
+    while True:
+        acc = tuple(itertools.islice(dataset, chunk_size))
+        if len(acc) == 0:
+            simple_queue.put(None)
+            print(f"Finished reading input file", flush=True)
+            return
+        simple_queue.put(acc)
+
 def log(readers, log_interval):
     print("Start Logging", flush=True)
     proc_start = time.time()
@@ -300,7 +322,12 @@ def main():
     process_ids = list(range(len(writers)))
     processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, process_id, args, level, writer)) for process_id, writer in zip(process_ids, writers)]
     log_thread = threading.Thread(target=log, args=(list(readers), args.log_interval))
-    fill_thread = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))
+    if os.path.isfile(args.input):
+        print("assuming `jsonl` input.")
+        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_file, args=(args.input, simple_queue, chunk_size))
+    elif os.path.isdir(args.input):
+        print("assuming arrow folder input for HF-datasets")
+        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_arrow, args=(args.input, simple_queue, chunk_size))
 
     fill_thread.start()
     log_thread.start()
@@ -332,7 +359,7 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    for key in args.content_keys:
         output_filename = f"{args.output_prefix}_{key}_{level}"
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -341,15 +368,15 @@ def main():
                                                impl=args.dataset_impl,
                                                dtype=best_dtype)
 
-    for key in args.json_keys:
+    for key in args.content_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             builders[key].merge_file_(output_filename)
         builders[key].finalize(output_idx_files[key])
 
     # Remove temporary files
     print("Removing shard files")
-    for key in args.json_keys:
+    for key in args.content_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             os.remove(data_file_path(output_filename))
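For context, a minimal sketch of how an arrow-backed input directory for the new code path might be produced. The directory name, the example rows, and the exact set of other command-line flags are illustrative assumptions, not taken from this commit.

import datasets

# Build a tiny HF dataset and save it to disk in arrow format (illustrative data).
ds = datasets.Dataset.from_dict({"text": ["first document", "second document"]})
ds.save_to_disk("my_dataset_arrow")

# The saved directory could then be passed as --input, e.g. (other required
# tokenizer/output flags omitted here):
#   python tools/preprocess_data_many_cores.py \
#       --input my_dataset_arrow --content-keys text --output-prefix my_corpus ...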

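A minimal sketch of the chunked reading that fill_simple_queue_from_arrow aims for, assuming a directory written with save_to_disk (the path and chunk size are illustrative). Note that itertools.islice only makes forward progress across calls when it is applied to one shared iterator, so the dataset is wrapped with iter() here; each yielded item is a dict, which is what the new isinstance(line, dict) branch of process_lines consumes.

import itertools

import datasets

def iter_chunks(dirname, chunk_size):
    # Load the arrow-backed dataset saved with save_to_disk.
    dataset = datasets.load_from_disk(dirname)
    it = iter(dataset)  # one shared iterator, so each islice call advances
    while True:
        chunk = tuple(itertools.islice(it, chunk_size))
        if not chunk:
            return
        yield chunk  # a tuple of dicts, one per document

for chunk in iter_chunks("my_dataset_arrow", chunk_size=2):
    print(len(chunk), "documents in this chunk")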
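To illustrate the TODO kept in Encoder.encode about characters versus bytes: len(text) counts Unicode characters, while the UTF-8 encoding of non-ASCII text occupies more bytes, so the reported byte totals undercount for non-ASCII scripts. A quick comparison:

# Character count vs. UTF-8 byte count (what the TODO in Encoder.encode refers to).
text = "héllo, 世界"
print(len(text))                  # 9 characters
print(len(text.encode("utf-8")))  # 14 bytes: é takes 2 bytes, each CJK char takes 3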