-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathpreprocess_prompt.py
More file actions
116 lines (91 loc) · 4.43 KB
/
preprocess_prompt.py
File metadata and controls
116 lines (91 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# coding=utf-8
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import copy
import multiprocessing
import os
import sys
from pathlib import Path
import hydra
from mindspeed_rl.config_cls.data_handler_config import DataHandlerConfig
from mindspeed_rl.config_cls.validate_config import validate_data_handler_config
from mindspeed_rl.datasets.indexed_dataset import IndexedDatasetBuilder
from mindspeed_rl.datasets.preprocess_data import merge_datasets, build_splitter, cut_range_to_subs, \
handle_subset
from mindspeed_rl.utils.tokenizer import get_tokenizer
from mindspeed_rl.datasets.data_handler import build_dataset, get_dataset_handler
from mindspeed_rl.utils.loggers import Loggers
# Module-wide logger for this preprocessing script.
logger = Loggers(name="process_data")
# Directory containing this script; used to anchor relative paths below.
cur_file_dir = Path(__file__).absolute().parent
# Path to the chat-template definitions shipped with the repo.
# NOTE(review): TEMPLATES_DIR is not referenced anywhere in this file —
# presumably consumed elsewhere (tokenizer/template machinery); confirm.
TEMPLATES_DIR = os.path.join(cur_file_dir, "./configs/rlhf/model/templates.json")
# Pop the config name off argv BEFORE hydra parses the command line: the
# first positional CLI argument selects which dataset config hydra loads
# (it is passed to @hydra.main(config_name=...) at decoration time).
config_name = sys.argv.pop(1)
# Repository root (parent of this script's directory); relative input/output
# paths are resolved against, and confined to, this directory.
base_dir = os.path.realpath(os.path.join(cur_file_dir, ".."))
def _resolve_under_base(path):
    """Resolve a relative *path* against ``base_dir`` and confine it there.

    Absolute paths are returned untouched (matching the original behavior:
    only relative paths are resolved and validated).

    Args:
        path: A filesystem path taken from user configuration.

    Returns:
        The canonical absolute path for a relative input; the input itself
        if it was already absolute.

    Raises:
        ValueError: If the resolved relative path escapes ``base_dir``.
    """
    if os.path.isabs(path):
        return path
    resolved = os.path.realpath(os.path.join(base_dir, path))
    # Compare against base_dir with a trailing separator: a bare
    # startswith(base_dir) would wrongly accept a sibling directory that
    # merely shares the prefix (e.g. "/repo-evil" when base_dir is "/repo").
    if resolved != base_dir and not resolved.startswith(base_dir + os.sep):
        raise ValueError(f"Invalid path: {resolved} is not within the allowed directory {base_dir} ")
    return resolved


def resolve_relative_path(args):
    """Normalize the input/tokenizer/output paths on *args* in place.

    Each relative path is resolved against the repository root (``base_dir``)
    and rejected if it resolves outside of it; absolute paths pass through.

    Args:
        args: Config object with ``input``, ``tokenizer_name_or_path`` and
            ``output_prefix`` attributes (mutated in place).

    Raises:
        ValueError: If any relative path escapes ``base_dir``.
    """
    args.input = _resolve_under_base(args.input)
    args.tokenizer_name_or_path = _resolve_under_base(args.tokenizer_name_or_path)
    args.output_prefix = _resolve_under_base(args.output_prefix)
def preprocess(config):
    """Validate the config, tokenize the raw dataset and serialize it to
    bin/idx files, optionally fanning the work out across processes.

    Args:
        config: Raw hydra config; wrapped into a ``DataHandlerConfig``.
    """
    args = DataHandlerConfig(config)
    resolve_relative_path(args)
    validate_data_handler_config(args)

    # Merge mode: combine already-processed datasets and stop.
    if args.merge_group_keys is not None:
        merge_datasets(args)
        return

    tokenizer = get_tokenizer(args.tokenizer_name_or_path,
                              prompt_type=args.prompt_type,
                              prompt_type_path=args.prompt_type_path,
                              enable_thinking=args.enable_thinking)
    splitter = build_splitter(args)

    logger.info(f"building dataset: {args.input}")
    raw_data = build_dataset(args)

    if args.n_subs == 1:
        # Single-process path: one handler serializes the whole dataset
        # to bin&idx.
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        handler.serialize_to_disk()
    else:
        _preprocess_parallel(args, raw_data, tokenizer, splitter)


def _preprocess_parallel(args, raw_data, tokenizer, splitter):
    """Split *raw_data* into ``args.n_subs`` contiguous subsets, serialize
    each subset in a worker process, then merge the per-subset bin/idx
    shards into the final files and delete the shards."""
    target_prefix = args.output_prefix
    target_prefixname = os.path.basename(target_prefix)

    num_samples = len(raw_data)
    start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
    subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]

    # One parameter bundle per subset; each worker writes files prefixed
    # "<k>_of_<n-1>_<prefix>" so shards sort (and merge) deterministically.
    params_list = []
    for k, subset in enumerate(subsets):
        args_ = copy.deepcopy(args)
        args_.output_prefix = target_prefix.replace(
            target_prefixname,
            f'{str(k).zfill(3)}_of_{str(len(subsets) - 1).zfill(3)}_{target_prefixname}')
        params_list.append([args_, subset, tokenizer, splitter])

    # try/finally guarantees worker processes are reaped even when a worker
    # raises (the original leaked the pool on exception).
    pool = multiprocessing.Pool()
    try:
        sub_idx_files = pool.map(handle_subset, params_list)
    finally:
        pool.close()
        pool.join()

    # Each worker returned a dict of idx files keyed the same way; merge
    # shard-by-shard for every key, then remove the per-subset files.
    for key in sub_idx_files[0].keys():
        idx_files = sorted(x[key] for x in sub_idx_files)
        target_idx = idx_files[0].replace(
            f'000_of_{str(len(subsets) - 1).zfill(3)}_{target_prefixname}',
            target_prefixname)
        target_bin = target_idx.replace('.idx', '.bin')
        builder = IndexedDatasetBuilder(target_bin)
        for idx_file in idx_files:
            builder.add_index(idx_file.replace('.idx', ''))
        builder.finalize(target_idx)
        for idx_file in idx_files:
            os.remove(idx_file)
            os.remove(idx_file.replace('.idx', '.bin'))
# config_name is read from the CLI at import time (sys.argv.pop(1) above),
# so the first positional argument picks the dataset config hydra loads.
@hydra.main(config_path="configs/rlhf/datasets", config_name=config_name)
def main(config):
    """Hydra entry point: run dataset preprocessing with the loaded config."""
    preprocess(config)
if __name__ == '__main__':
    main()