forked from LAIR-RCC/ruadapt
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path extend_or_replace.py
More file actions
95 lines (74 loc) · 3.75 KB
/
extend_or_replace.py
File metadata and controls
95 lines (74 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import os
import shutil
import subprocess
import sys


def run_step(cmd):
    """Run *cmd* as a subprocess and abort the pipeline if it fails.

    On a non-zero return code: print the code, print an error banner,
    and exit the whole script with status 1 (so later steps never run
    on a half-finished artifact).
    """
    ret = subprocess.call(cmd)
    if ret != 0:
        print(ret)
        print('ERROR. Stopping pipeline')
        sys.exit(1)


def parse_args(argv=None):
    """Parse CLI options; *argv* defaults to sys.argv[1:] (injectable for tests)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--op')  # 'replace' or 'extend' (validated in main)
    parser.add_argument('--src_model_path')
    parser.add_argument('--output_path')
    parser.add_argument('--replace_tokenizer_path', default='')
    parser.add_argument('--extend_tiktoken_tokenizer_path', default='')
    parser.add_argument('--extend_hf_tokenizer_path', default='')
    parser.add_argument('--extend_hf_tokenizer_type', default='')
    parser.add_argument('--only_ru', action='store_true')
    parser.add_argument('--filter_numbers', action='store_true')
    parser.add_argument('--custom_tokens_path', default=None)
    parser.add_argument('--init_mode', default='mean')  # embedding init mode for new tokens
    parser.add_argument('--mult', default=1.0, type=float)
    return parser.parse_args(argv)


def replace_tokenizer(args):
    """Replace the model's tokenizer with the one at --replace_tokenizer_path."""
    run_step(['python', '-m', 'ruadapt.tokenization.run_replace_tokenizer',
              '--model_name_or_path', args.src_model_path,
              '--new_tokenizer_path', args.replace_tokenizer_path,
              '--output_path', args.output_path])


def extend_tokenizer(args):
    """Extend the model's tiktoken tokenizer with tokens from an HF tokenizer.

    Pipeline (each step aborts the script on failure):
      1. Convert the HF tokenizer vocab to a frequency list.
      2. Merge the new tokens into a copy of the base tiktoken file.
      3. Expand the tiktoken tokenizer and save it in HF format.
      4. Replace the model's tokenizer, re-initialising new-token embeddings
         with --init_mode / --mult.
    """
    freq_path = os.path.join(args.output_path, 'sp_tok_vocab_freq.txt')
    cmd = ['python', '-m', 'ruadapt.tokenization.convert_hf_tokenizer_vocab_to_freq_list',
           '--tokenizer_path', args.extend_hf_tokenizer_path,
           '--output_path', freq_path,
           '--type', args.extend_hf_tokenizer_type]
    if args.custom_tokens_path is not None:
        cmd += ['--custom_tokens_path', args.custom_tokens_path]
    if args.only_ru:
        cmd.append('--only_ru')
    run_step(cmd)

    extended_part_path = os.path.join(args.output_path, 'tokenizer_extended_part.tiktoken')
    base_path = os.path.join(args.output_path, 'tokenizer_base.tiktoken')
    # Work on a copy of the base tokenizer so the source file is never mutated.
    shutil.copyfile(args.extend_tiktoken_tokenizer_path, base_path)
    run_step(['python', '-m', 'ruadapt.tokenization.add_merges',
              '--input_path', base_path,
              '--output_path', extended_part_path,
              '--vocab_path', freq_path,
              '--start_id', str(-1)])

    hf_tokenizer_dir = os.path.join(args.output_path, 'hf_tokenizer')
    cmd = ['python', '-m', 'ruadapt.tokenization.expand_tiktoken_save_hf',
           '--tiktoken_base_path', base_path,
           '--tiktoken_new_path', extended_part_path,
           '--output_dir', hf_tokenizer_dir,
           '--init_output_from', args.src_model_path]
    if args.filter_numbers:
        cmd.append('--filter_numbers')
    run_step(cmd)

    # Was previously unchecked: a failure here went unnoticed. Now routed
    # through run_step like every other stage.
    run_step(['python', '-m', 'ruadapt.tokenization.run_replace_tokenizer',
              '--model_name_or_path', args.src_model_path,
              '--new_tokenizer_path', hf_tokenizer_dir,
              '--output_path', args.output_path,
              '--mode', args.init_mode,
              '--mult', str(args.mult)])


def main():
    args = parse_args()
    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if args.op not in ('replace', 'extend'):
        sys.exit(f"--op must be 'replace' or 'extend', got {args.op!r}")
    os.makedirs(args.output_path, exist_ok=True)
    if args.op == 'replace':
        replace_tokenizer(args)
    else:
        extend_tokenizer(args)


if __name__ == '__main__':
    main()