diff --git a/experiments/language_model/prepare_data.py b/experiments/language_model/prepare_data.py
index 16d47bd..b1edb34 100644
--- a/experiments/language_model/prepare_data.py
+++ b/experiments/language_model/prepare_data.py
@@ -10,23 +10,20 @@ def tokenize_data(input, output=None, max_seq_length=512):
   if output is None:
     output=input + '.spm'
   all_tokens = []
-  with open(input, encoding = 'utf-8') as fs:
-    for l in tqdm(fs, ncols=80, desc='Loading'):
+  lines = 0
+  wfs = open(output, 'w', encoding = 'utf-8')
+  with open(input, encoding='utf-8') as fs:
+    for l in tqdm(fs, ncols=80, desc='processing...'):
       if len(l) > 0:
         tokens = tokenizer.tokenize(l)
       else:
         tokens = []
       all_tokens.extend(tokens)
-
-  print(f'Loaded {len(all_tokens)} tokens from {input}')
-  lines = 0
-  with open(output, 'w', encoding = 'utf-8') as wfs:
-    idx = 0
-    while idx < len(all_tokens):
-      wfs.write(' '.join(all_tokens[idx:idx+max_seq_length-2]) + '\n')
-      idx += (max_seq_length - 2)
-      lines += 1
-
+      if len(all_tokens) >= max_seq_length-2:
+        wfs.write(' '.join(all_tokens[:max_seq_length-2]) + '\n')
+        all_tokens = all_tokens[max_seq_length-2:]
+        lines += 1
+  wfs.close()
   print(f'Saved {lines} lines to {output}')
 
 parser = argparse.ArgumentParser()
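This change streams sequences to disk as tokens accumulate instead of materializing the whole corpus in memory before writing, which keeps peak memory flat for large input files. Two side effects of the new loop are worth noting: the single `if` flushes at most one chunk per input line, so an unusually long line can leave the buffer above the threshold until later lines drain it, and trailing tokens shorter than `max_seq_length - 2` are now dropped, whereas the old `while` loop wrote a final partial sequence. The sketch below illustrates the same streaming approach with both points addressed. It is a minimal, self-contained version, not the PR's code: the whitespace `tokenize()` is a hypothetical stand-in for the script's tokenizer, and the comment that the two reserved positions are for special tokens added downstream is an inference from the `- 2`, not something this diff confirms.

```python
def tokenize(line):
  # Hypothetical stand-in; the real script calls tokenizer.tokenize(line).
  return line.split()

def tokenize_data(input_path, output_path=None, max_seq_length=512):
  if output_path is None:
    output_path = input_path + '.spm'
  # Presumably two positions are reserved for special tokens added later.
  chunk = max_seq_length - 2
  buffer, lines = [], 0
  # One `with` block for both files: the writer is closed (and its buffers
  # flushed) even if tokenization raises part-way through, which removes
  # the need for an explicit wfs.close().
  with open(input_path, encoding='utf-8') as fs, \
       open(output_path, 'w', encoding='utf-8') as wfs:
    for line in fs:
      buffer.extend(tokenize(line))
      # `while`, not `if`: a single long line may yield several chunks,
      # and each is flushed immediately so the buffer stays bounded.
      while len(buffer) >= chunk:
        wfs.write(' '.join(buffer[:chunk]) + '\n')
        buffer = buffer[chunk:]
        lines += 1
    if buffer:
      # Flush the final partial sequence instead of silently dropping it.
      wfs.write(' '.join(buffer) + '\n')
      lines += 1
  print(f'Saved {lines} lines to {output_path}')
```

The `buffer = buffer[chunk:]` slice copies the remainder on each flush; since the buffer never grows much past one chunk here, that cost is negligible, but a `collections.deque` would avoid it if chunk sizes were much larger.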