Skip to content

Commit 8c8089f

Browse files
committed
add gsmk test script
1 parent 080ae7e commit 8c8089f

File tree

1 file changed

+230
-0
lines changed

1 file changed

+230
-0
lines changed

test/test_api/test_gsmk.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
2+
import argparse
3+
import ast
4+
import json
5+
import os
6+
import re
7+
import time
8+
from concurrent.futures import ThreadPoolExecutor
9+
from typing import Optional
10+
11+
import numpy as np
12+
import requests
13+
from tqdm import tqdm
14+
15+
INVALID = -9999999
16+
17+
18+
def read_jsonl(filename: str):
    """Yield one parsed JSON object per non-comment line of *filename*."""
    with open(filename) as handle:
        for raw in handle:
            # Lines beginning with '#' are treated as comments and skipped.
            if not raw.startswith("#"):
                yield json.loads(raw)
25+
26+
27+
def dump_state_text(filename: str, states: list, mode: str = "w"):
    """Dump program state in a text file, one '==== i ====' section per entry."""
    with open(filename, mode) as fout:
        for idx, state in enumerate(states):
            # Non-string states are rendered with str(); strings pass through,
            # which produces exactly the same bytes either way.
            text = state if isinstance(state, str) else str(state)
            fout.write(f"==== {idx} ====\n{text}\n")
35+
36+
37+
def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download *url* to *filename*, reusing any cached copy.

    Args:
        url: Source URL to fetch.
        filename: Destination path; defaults to /tmp/<last URL segment>.

    Returns:
        The local path of the (possibly pre-existing) file.

    Raises:
        requests.HTTPError: if the download request fails.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    # Fix: the message previously printed a literal "(unknown)" instead of
    # the actual destination path.
    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1KB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as file, tqdm(
        desc="Downloading",
        total=total_size,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            size = file.write(chunk)
            bar.update(size)

    return filename
69+
70+
71+
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
    """Call LightLLM API for text generation and return the first generation."""
    assert url is not None

    payload = {
        "inputs": prompt,
        "parameters": {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "stop_sequences": stop,
        },
    }
    res = requests.post(url, json=payload)
    assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}"

    body = res.json()
    # Validate the response shape before indexing into it.
    if "generated_text" not in body:
        raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {body.keys()}")
    generated = body["generated_text"]
    if not isinstance(generated, list) or len(generated) == 0:
        raise ValueError(
            "Invalid API response format. 'generated_text' should be a non-empty list, "
            f"got: {generated}"
        )

    return generated[0]
97+
98+
99+
def get_one_example(lines, i, include_answer):
    """Format example *i* as a 'Question: ... Answer:' prompt, optionally with the answer."""
    parts = ["Question: ", lines[i]["question"], "\nAnswer:"]
    if include_answer:
        parts += [" ", lines[i]["answer"]]
    return "".join(parts)
104+
105+
106+
def get_few_shot_examples(lines, k):
    """Concatenate the first *k* examples (with answers) into a few-shot prompt prefix."""
    return "".join(get_one_example(lines, idx, True) + "\n\n" for idx in range(k))
111+
112+
113+
def get_answer_value(answer_str):
    """Extract the last integer from *answer_str*; return INVALID if none parses."""
    # Strip thousands separators first so "1,234" reads as one number.
    digits = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digits:
        return INVALID
    try:
        return ast.literal_eval(digits[-1])
    except SyntaxError:
        return INVALID
122+
123+
124+
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to None, which
            makes argparse read ``sys.argv[1:]`` — identical to the previous
            no-parameter behavior — while allowing programmatic/test use.

    Returns:
        argparse.Namespace with the benchmark configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--num-questions", type=int, default=200)
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    return parser.parse_args(argv)
135+
136+
137+
def main(args):
    """Run the GSM8K few-shot benchmark against a LightLLM endpoint.

    Downloads the GSM8K test split, builds few-shot prompts, queries the
    server (optionally in parallel), then reports accuracy/latency and
    appends a JSONL result record to ``args.result_file``.
    """
    # LightLLM API URL
    url = f"{args.host}:{args.port}/generate"

    # Read data
    url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    filename = download_and_cache_file(url_data)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    # Ensure we have enough samples and avoid data leakage
    # Test questions should start after few-shot examples
    max_available = len(lines) - num_shots
    if num_questions > max_available:
        print(
            "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. "
            "Using {} questions.".format(num_questions, max_available, num_shots, max_available)
        )
        num_questions = max_available

    # Evaluation questions start right after the examples used for few-shot.
    questions = []
    labels = []
    for i in range(num_shots, num_shots + num_questions):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    # Every gold answer must contain a parseable number.
    assert all(label != INVALID for label in labels)

    # One slot per question, filled in (possibly out of order) by workers.
    states = [None] * len(labels)

    # Run requests using thread pool
    def get_one_answer(i):
        # temperature=0 for deterministic decoding; stop strings cut the
        # completion before the model starts a new question.
        answer = call_generate_lightllm(
            prompt=few_shot_examples + questions[i],
            temperature=0,
            max_tokens=256,
            stop=["Question", "Assistant:", "<|separator|>"],
            url=url,
        )
        states[i] = answer

    tic = time.perf_counter()
    if args.parallel == 1:
        # Sequential path keeps a simple per-question progress bar.
        for i in tqdm(range(len(questions))):
            get_one_answer(i)
    else:
        # list() drains the lazy map so all requests finish inside the timer.
        with ThreadPoolExecutor(args.parallel) as executor:
            list(
                tqdm(
                    executor.map(get_one_answer, list(range(len(questions)))),
                    total=len(questions),
                )
            )

    # Wall-clock time for the whole request batch.
    latency = time.perf_counter() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")

    # Dump results
    dump_state_text("tmp_output_lightllm.txt", states)

    # Append one JSONL record per run.
    # NOTE(review): "num_requests" records args.num_questions, which may exceed
    # the clamped num_questions actually issued — confirm intended.
    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k",
            "backend": "lightllm",
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
226+
227+
228+
if __name__ == "__main__":
    # Parse CLI flags and run the benchmark end to end.
    main(parse_args())

0 commit comments

Comments
 (0)