Commit 497aa1b

Sorry, last change was meant for a PR. This reverts commit d0fcf41.
1 parent d0fcf41 commit 497aa1b

File tree: 1 file changed, +26 -53 lines changed

tools/preprocess_data_many_cores.py
Lines changed: 26 additions & 53 deletions

@@ -35,8 +35,6 @@
 import sys
 import threading
 import time
-
-import datasets
 import torch
 from multiprocessing.connection import Connection

@@ -73,7 +71,7 @@ def tokenize(self, *text):

 class Encoder(object):
     def __init__(self, args):
-        self.content_keys = args.content_keys
+        self.json_keys = args.json_keys
         self.append_eod = args.append_eod
         # Use Encoder class as a container for global data
         self.tokenizer = build_tokenizer(args)
@@ -93,14 +91,11 @@ def __init__(self, args):
         else:
             self.splitter = IdentitySplitter()

-    def encode(self, data):
+    def encode(self, json_line):
+        data = json.loads(json_line)
         ids = {}
-        # TODO: a character is not a byte for non-ascii scripts. this was like this before, maybe fix at some point
-        # (counting the actual bytes will slow down processing though)
-        bytes = 0
-        for key in self.content_keys:
+        for key in self.json_keys:
             text = data[key]
-            bytes += len(text)
             doc_ids = []
             for sentence in self.splitter.tokenize(text):
                 sentence_ids = self.tokenizer.tokenize(sentence)
@@ -109,7 +104,7 @@ def encode(self, data):
             if len(doc_ids) > 0 and self.append_eod:
                 doc_ids[-1].append(self.tokenizer.eod)
             ids[key] = doc_ids
-        return ids, bytes
+        return ids, len(json_line)


 def process_samples(simple_queue, process_id, args, level, writer: Connection):
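
For orientation, the contract that the restored encode() follows can be sketched in isolation: the worker hands it a raw jsonl line, and it returns a dict mapping each --json-keys entry to a list of per-sentence token-id lists, plus len(json_line) for throughput accounting. The stub tokenizer and key below are placeholders for illustration only, not the script's real build_tokenizer(args) output.

# Illustrative sketch only -- StubTokenizer and the "text" key are stand-ins.
import json

class StubTokenizer:
    eod = 0  # placeholder end-of-document token id
    def tokenize(self, sentence):
        return [abs(hash(tok)) % 50257 for tok in sentence.split()]

class StubEncoder:
    def __init__(self, json_keys=("text",), append_eod=True):
        self.json_keys = list(json_keys)
        self.append_eod = append_eod
        self.tokenizer = StubTokenizer()

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.json_keys:
            # Without --split-sentences the whole text is one "sentence".
            doc_ids = [self.tokenizer.tokenize(data[key])]
            if len(doc_ids) > 0 and self.append_eod:
                doc_ids[-1].append(self.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)

if __name__ == "__main__":
    doc, n_chars = StubEncoder().encode('{"text": "hello many cores"}')
    print(doc["text"], n_chars)  # [[...token ids..., 0]] and the raw line length
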
@@ -118,7 +113,7 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.content_keys:
+    for key in args.json_keys:
         output_filename = get_output_filename(args.output_prefix, key, level, process_id)
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -127,38 +122,33 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
                                                impl=args.dataset_impl,
                                                dtype=best_dtype)

-    doc_lines = simple_queue.get()
-    while doc_lines is not None:
-        process_lines(doc_lines, encoder, builders, writer)
+    json_lines = simple_queue.get()
+    while json_lines is not None:
+        process_json_lines(json_lines, encoder, builders, writer)

-        doc_lines = simple_queue.get()
+        json_lines = simple_queue.get()

     # In case finished, we still need to add None to signal to everyone else
     simple_queue.put(None)
     # Send None as end of sequence signal
     writer.send((None, process_id))
     writer.close()

-    for key in args.content_keys:
+    for key in args.json_keys:
         builders[key].finalize(output_idx_files[key])

     print(f"Worker {process_id} finished", flush=True)


-def process_lines(lines, encoder, builders, writer):
+def process_json_lines(json_lines, encoder, builders, writer):
     total_bytes_processed = 0
-    for line in lines:
+    for json_line in json_lines:
+        if json_line.strip() == "":
+            continue

-        if isinstance(line, str):
-            if line.strip() == "":
-                continue
-            data = json.loads(line)
-            doc, bytes_processed = encoder.encode(data)
-            total_bytes_processed += bytes_processed
+        doc, bytes_processed = encoder.encode(json_line)

-        elif isinstance(line, dict):
-            doc, bytes_processed = encoder.encode(line)
-            total_bytes_processed += bytes_processed
+        total_bytes_processed += bytes_processed

         for key, sentences in doc.items():
             if len(sentences) == 0:
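
The worker loop restored above relies on a sentinel-propagation pattern: each worker pulls chunks from the shared queue until it sees None, re-puts None so its sibling workers also stop, and sends (None, process_id) over its Pipe so the logging side knows it is done. A minimal standalone sketch of that pattern (simplified names, no encoder or index builders) follows.

# Simplified sketch of the queue/sentinel pattern, not the commit's code.
import multiprocessing

def worker(simple_queue, process_id, writer):
    chunk = simple_queue.get()
    while chunk is not None:
        # Stand-in for process_json_lines(chunk, encoder, builders, writer).
        writer.send((len(chunk), sum(len(line) for line in chunk)))
        chunk = simple_queue.get()
    simple_queue.put(None)           # propagate the end-of-input sentinel
    writer.send((None, process_id))  # tell the log side this worker is done
    writer.close()

if __name__ == "__main__":
    queue = multiprocessing.SimpleQueue()
    readers, processes = [], []
    for process_id in range(3):
        reader, writer = multiprocessing.Pipe(duplex=False)
        readers.append(reader)
        processes.append(multiprocessing.Process(
            target=worker, args=(queue, process_id, writer)))
    for p in processes:
        p.start()
    for chunk in (['{"text": "a"}'] * 4, ['{"text": "b"}'] * 2):
        queue.put(chunk)
    queue.put(None)                  # one sentinel; workers re-put it for siblings
    for p in processes:
        p.join()
    print([r.recv() for r in readers])  # first message each worker reported
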
@@ -167,16 +157,16 @@ def process_lines(lines, encoder, builders, writer):
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()

-    writer.send((len(lines), total_bytes_processed))
+    writer.send((len(json_lines), total_bytes_processed))


 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
     group.add_argument('--input', type=str, required=True,
-                       help='Path to input JSON or arrow file')
-    group.add_argument('--content-keys', nargs='+', default=['text'],
-                       help='space separate listed of keys to extract from data')
+                       help='Path to input JSON')
+    group.add_argument('--json-keys', nargs='+', default=['text'],
+                       help='space separate listed of keys to extract from json')
     group.add_argument('--split-sentences', action='store_true',
                        help='Split documents into sentences.')
     group.add_argument('--keep-newlines', action='store_true',
@@ -229,7 +219,7 @@ def get_args():

     return args

-def fill_simple_queue_from_file(filename, simple_queue, chunk_size:int):
+def fill_simple_queue(filename, simple_queue, chunk_size:int):
     # TODO: Assess if instead we could feed pointers which process can then load.
     with open(filename, "r") as f:
         print("Start filling queue", flush=True)
@@ -241,18 +231,6 @@ def fill_simple_queue_from_file(filename, simple_queue, chunk_size:int):
                 return
             simple_queue.put(acc)

-def fill_simple_queue_from_arrow(dirname, simple_queue, chunk_size:int):
-    # TODO: Assess if instead we could feed pointers which process can then load.
-    dataset = datasets.load_from_disk(dirname)
-    print("Start filling queue", flush=True)
-    while True:
-        acc = tuple(itertools.islice(dataset, chunk_size))
-        if len(acc) == 0:
-            simple_queue.put(None)
-            print(f"Finished reading input file", flush=True)
-            return
-        simple_queue.put(acc)
-
 def log(readers, log_interval):
     print("Start Logging", flush=True)
     proc_start = time.time()
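
The surviving context lines above (the with open(...) header and the return / simple_queue.put(acc) tail) suggest that the restored fill_simple_queue reads the jsonl file in chunks of chunk_size lines via itertools.islice and finishes by putting a single None sentinel, mirroring the structure of the deleted arrow variant. A reconstruction under that assumption, for readers who only see this diff:

# Assumed reconstruction -- the middle of the real function lies outside this hunk.
import itertools

def fill_simple_queue_sketch(filename, simple_queue, chunk_size: int):
    # Mirrors the deleted arrow variant, but iterates over the open jsonl
    # file handle instead of a datasets object loaded from disk.
    with open(filename, "r") as f:
        print("Start filling queue", flush=True)
        while True:
            acc = tuple(itertools.islice(f, chunk_size))
            if len(acc) == 0:
                simple_queue.put(None)
                print("Finished reading input file", flush=True)
                return
            simple_queue.put(acc)
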
@@ -322,12 +300,7 @@ def main():
     process_ids = list(range(len(writers)))
     processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, process_id, args, level, writer)) for process_id, writer in zip(process_ids, writers)]
     log_thread = threading.Thread(target=log, args=(list(readers), args.log_interval))
-    if os.path.isfile(args.input):
-        print("assuming `jsonl` input.")
-        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_file, args=(args.input, simple_queue, chunk_size))
-    elif os.path.isdir(args.input):
-        print("assuming arrow folder input for HF-datasets")
-        fill_thread = multiprocessing.Process(target=fill_simple_queue_from_arrow, args=(args.input, simple_queue, chunk_size))
+    fill_thread = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))

     fill_thread.start()
     log_thread.start()
@@ -359,7 +332,7 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.content_keys:
+    for key in args.json_keys:
         output_filename = f"{args.output_prefix}_{key}_{level}"
         output_bin_files[key] = data_file_path(output_filename)
         output_idx_files[key] = index_file_path(output_filename)
@@ -368,15 +341,15 @@ def main():
                                                impl=args.dataset_impl,
                                                dtype=best_dtype)

-    for key in args.content_keys:
+    for key in args.json_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             builders[key].merge_file_(output_filename)
         builders[key].finalize(output_idx_files[key])

     # Remove temporary files
     print("Removing shard files")
-    for key in args.content_keys:
+    for key in args.json_keys:
         for process_id in process_ids:
             output_filename = get_output_filename(args.output_prefix, key, level, process_id)
             os.remove(data_file_path(output_filename))
