Commit 90b321c

Merge branch 'dev'
2 parents: 82999d6 + e4d18f8


51 files changed: +6372, -1840 lines

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -1,2 +1,3 @@
-
+**/__pycache__
+**/*.pyc
 .DS_Store
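
For example, `**/__pycache__` now ignores `__pycache__` directories at any depth in the tree, and `**/*.pyc` ignores compiled Python bytecode files in any subdirectory.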

Data/Ruler/prepare.py

Lines changed: 150 additions & 0 deletions
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Prepare jsonl with fields `input` and `outputs`.
{
    "index": int,
    "input": str,
    "outputs": [str],
}

python prepare.py \
    --save_dir ./ \
    --benchmark synthetic \
    --task niah_single_1 \
    --tokenizer_path tokenizer.model \
    --tokenizer_type nemo \
    --max_seq_length 4096 \
    --model_template_type base \
    --num_samples 10
"""
import os
import argparse
import importlib
import subprocess
import time
import yaml
from pathlib import Path
from template import Templates

import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')


parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]')
parser.add_argument("--task", type=str, required=True, help='task in benchmark')
parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
parser.add_argument("--tokenizer_type", type=str, default='nemo', help='Options: nemo, hf, openai')
parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length, including all input tokens and generated tokens')
parser.add_argument("--num_samples", type=int, default=500, help='maximum number of samples to test')
parser.add_argument("--random_seed", type=int, default=42)
parser.add_argument("--model_template_type", type=str, default='base', help='Options in `template.py`')
parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\\n` and `\\t` in all strings')
parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk')
parser.add_argument("--chunk_amount", type=int, default=1, help='number of split chunks')

args = parser.parse_args()

def main():
    start_time = time.time()
    curr_folder = os.path.dirname(os.path.abspath(__file__))

    try:
        module = importlib.import_module(f"{args.benchmark}.constants")
    except ImportError:
        print(f"Module {args.benchmark}.constants not found.")
        raise

    tasks_base = module.TASKS
    with open(os.path.join(curr_folder, f"{args.benchmark}.yaml"), "r") as f:
        tasks_customized = yaml.safe_load(f)

    if args.task not in tasks_customized:
        raise ValueError(f'{args.task} is not found in {args.benchmark}.yaml')

    config = tasks_customized.get(args.task)
    config.update(tasks_base[config['task']])

    # Add templates
    assert args.model_template_type in Templates, f'{args.model_template_type} is not found in {Templates.keys()}'
    model_template = Templates[args.model_template_type]
    task_template = config['template']

    # Add answer prefix for all models
    answer_prefix = config['answer_prefix'] if 'answer_prefix' in config else ''
    config['template'] = model_template.format(task_template=task_template) + answer_prefix

    # Split the task into chunks; the first (num_samples % chunk_amount) chunks
    # get one extra sample, e.g. num_samples=10, chunk_amount=3 -> [4, 3, 3].
    chunks = [(args.num_samples // args.chunk_amount) + (1 if i < args.num_samples % args.chunk_amount else 0) for i in range(args.chunk_amount)]
    num_samples = chunks[args.chunk_idx]
    pre_samples = sum(chunks[:args.chunk_idx])

    random_seed = args.random_seed + args.chunk_idx

    save_file = args.save_dir / args.task / f"{args.subset}.jsonl"
    file_exists = False
    if os.path.exists(save_file):
        with open(save_file, "r") as f:
            data = f.readlines()
        if len(data) == args.num_samples:
            file_exists = True

    if not file_exists:
        try:
            script = os.path.join(curr_folder, args.benchmark, f"{config['task']}.py")
            additional_args = " ".join([f"--{k} {v}" for k, v in config['args'].items()])
            command = f"""python {script} \
                --save_dir {args.save_dir} \
                --save_name {args.task} \
                --subset {args.subset} \
                --tokenizer_path {args.tokenizer_path} \
                --tokenizer_type {args.tokenizer_type} \
                --max_seq_length {args.max_seq_length} \
                --tokens_to_generate {config['tokens_to_generate']} \
                --num_samples {num_samples} \
                --random_seed {random_seed} \
                {additional_args} \
                {"--remove_newline_tab" if args.remove_newline_tab else ""} \
                {f"--pre_samples {pre_samples}" if config['task'] == 'qa' else ""} \
                --template "{config['template']}"
            """
            print(command)
            result = subprocess.run(command,
                                    shell=True,
                                    check=True,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    text=True)

            if result.returncode == 0:
                print("Output:")
                print(result.stdout)
            else:
                print("Error:")
                print(result.stderr)
        except subprocess.CalledProcessError as e:
            print("Error output:", e.stderr)

        print(f"Prepared {args.task}: {args.num_samples} lines to {save_file}")
        print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes")
    else:
        print(f"Skip preparing {args.task}: {save_file} already has {args.num_samples} lines")

if __name__ == '__main__':
    main()
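
For reference, a minimal sketch of consuming the jsonl this script produces. The schema follows the module docstring above (`index`, `input`, `outputs`); the concrete path is a hypothetical example mirroring the `save_dir / task / subset` layout in `main()`:

import json
from pathlib import Path

# Hypothetical path, mirroring save_file = args.save_dir / args.task / f"{args.subset}.jsonl".
save_file = Path("./niah_single_1/validation.jsonl")

with save_file.open() as f:
    for line in f:
        sample = json.loads(line)
        # Fields per the prepare.py docstring: index (int), input (str), outputs (list of str).
        print(sample["index"], len(sample["input"]), sample["outputs"])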

Data/Ruler/synthetic.yaml

Lines changed: 121 additions & 0 deletions
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

niah_single_1:
  task: niah
  args:
    type_haystack: repeat
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1

niah_single_2:
  task: niah
  args:
    type_haystack: essay
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1

niah_single_3:
  task: niah
  args:
    type_haystack: essay
    type_needle_k: words
    type_needle_v: uuids
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1

niah_multikey_1:
  task: niah
  args:
    type_haystack: essay
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 4
    num_needle_v: 1
    num_needle_q: 1

niah_multikey_2:
  task: niah
  args:
    type_haystack: needle
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1

niah_multikey_3:
  task: niah
  args:
    type_haystack: needle
    type_needle_k: uuids
    type_needle_v: uuids
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1

niah_multivalue:
  task: niah
  args:
    type_haystack: essay
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 4
    num_needle_q: 1

niah_multiquery:
  task: niah
  args:
    type_haystack: essay
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 4

vt:
  task: variable_tracking
  args:
    num_chains: 1
    num_hops: 4

cwe:
  task: common_words_extraction
  args:
    freq_cw: 30
    freq_ucw: 3
    num_cw: 10

fwe:
  task: freq_words_extraction
  args:
    alpha: 2.0

qa_1:
  task: qa
  args:
    dataset: squad

qa_2:
  task: qa
  args:
    dataset: hotpotqa
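To see how prepare.py consumes one of these entries: it loads this YAML, looks up the base settings for the entry's `task` type in `{benchmark}/constants.py`, and merges the two dicts. Below is a minimal sketch of that lookup-and-merge step; the `TASKS` dict is a hypothetical stand-in for `synthetic/constants.py` (not part of this commit excerpt), showing only the keys prepare.py actually reads:

import yaml

# Hypothetical stand-in for synthetic/constants.py; prepare.py reads
# 'template', 'tokens_to_generate', and optionally 'answer_prefix' from it.
TASKS = {
    'niah': {
        'tokens_to_generate': 128,
        'template': 'Haystack: {context}\nQuestion: {query}',
        'answer_prefix': ' Answer:',
    },
}

with open('synthetic.yaml') as f:
    tasks_customized = yaml.safe_load(f)

config = tasks_customized['niah_single_1']  # {'task': 'niah', 'args': {...}}
config.update(TASKS[config['task']])        # same merge as in prepare.py's main()

print(config['args'])                # generator arguments from this YAML
print(config['tokens_to_generate'])  # inherited from the base TASKS entry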