
Commit 739a1ca

feat: add benchmark_qps in test (#832)
1 parent 28bf517 commit 739a1ca

4 files changed (+374, -3 lines)

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,11 @@ def __init__(self, kvargs):
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
+        self.graph_max_batch_size = (
+            self.graph_max_batch_size // 2
+            if get_env_start_args().enable_decode_microbatch_overlap
+            else self.graph_max_batch_size
+        )
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
         self.quant_type = kvargs.get("quant_type", "none")
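
For illustration only (not part of the commit): the added lines halve the CUDA-graph batch budget when decode micro-batch overlap is enabled and leave it unchanged otherwise. A minimal sketch of the resulting value, assuming the flag is read from the launch args as in the diff above:

    def effective_graph_max_batch_size(configured: int, microbatch_overlap: bool) -> int:
        # Mirrors the diff above: halve the CUDA-graph batch budget when decode
        # micro-batch overlap is enabled, otherwise keep the configured value.
        return configured // 2 if microbatch_overlap else configured

    # e.g. the default of 16 becomes 8 when enable_decode_microbatch_overlap is set
    assert effective_graph_max_batch_size(16, True) == 8
    assert effective_graph_max_batch_size(16, False) == 16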

lightllm/utils/envs_utils.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def enable_env_vars(args):

 @lru_cache(maxsize=None)
 def get_deepep_num_max_dispatch_tokens_per_rank():
+    # Must be larger than the per-GPU max batch size and a multiple of 8. Larger values use more GPU memory; if you run out of memory, try lowering it.
     return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256))
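
As an illustration (not part of the commit), the budget can be overridden through the environment before the lru_cache-wrapped getter above is first called; the concrete value 128 is only an example:

    import os

    # Illustrative override: shrink the DeepEP dispatch-token budget to save GPU memory.
    # Per the comment above, keep it above the per-GPU max batch size and a multiple of 8.
    os.environ["NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = "128"  # default is 256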

test/benchmark_qps.py

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
import os
import argparse
import yaml
import requests
import json
import time
import random
import numpy as np
from tqdm import tqdm
from typing import Union, List, Tuple
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
import aiohttp
import asyncio


def seed_all(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)


def get_tokenizer(
    tokenizer_name: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    return tokenizer


def get_output_length(reqs_num: int, output_len: int) -> List[int]:
    # Draw per-request output lengths from a Gaussian clipped to [2, 2 * output_len].
    min_len, max_len = 2, output_len * 2
    mean = (min_len + max_len) * 0.5
    std = mean
    output_lens = []
    for _ in range(reqs_num):
        cur_len = random.gauss(mean, std)
        cur_len = round(cur_len)
        if cur_len < min_len:
            cur_len = min_len
        elif cur_len > max_len:
            cur_len = max_len
        output_lens.append(cur_len)
    return output_lens


def gen_random_input_text(tokenizer) -> str:
    random_ids = [random.randint(512, 8192) for _ in range(1024)]
    random_text = tokenizer.decode(random_ids)
    return random_text


def gen_random_data(
    input_len: int, output_len: int, reqs_num: int, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Tuple[List[str], List[int], List[int]]:
    prompts = []
    input_lens = []
    output_lens = get_output_length(reqs_num, output_len)
    for i in range(reqs_num):
        input_text = gen_random_input_text(tokenizer)
        prompts.append(input_text)
        input_lens.append(input_len)
    print("Finished generating random data.")
    return prompts, input_lens, output_lens


async def async_post_stream_lightllm(url, text_input, max_new_tokens, session) -> List[float]:
    # Send one streaming request and record the elapsed time between streamed chunks.
    try:
        data = {
            "inputs": text_input,
            "parameters": {
                "do_sample": False,
                "ignore_eos": True,
                "max_new_tokens": max_new_tokens,
            },
        }
        headers = {"Content-Type": "application/json"}
        used_time = []
        start_time = time.time()
        last_time = start_time

        async with session.post(url, headers=headers, json=data) as response:
            if response.status != 200:
                return []

            async for line in response.content:
                if line and line.startswith(b"data:"):
                    # print(line)
                    current_time = time.time()
                    elapsed_time = current_time - last_time
                    used_time.append(elapsed_time)
                    last_time = current_time
        return used_time
    except Exception:
        pass


async def continuous_sender(
    session, pending_tasks, async_task, url, prompts, max_new_tokens, request_queue, stop_send, sent_count, input_qps
):
    prompt_index = 0
    while not stop_send.is_set():
        prompt = prompts[prompt_index % len(prompts)]
        max_tokens = max_new_tokens[prompt_index % len(max_new_tokens)]

        task = asyncio.create_task(async_task(url, prompt, max_tokens, session))
        pending_tasks.append(task)
        await request_queue.put(task)

        prompt_index += 1
        sent_count[0] += 1
        # Throttle the send rate.
        await asyncio.sleep(1.0 / input_qps)


async def response_collector(
    request_queue,
    results,
    reqs_num,
    stop_event,
    stop_send,
    counter,
    end_time,
    sent_count,
    force_terminate,
    pending_tasks,
):
    try:
        while True:
            try:
                task = await asyncio.wait_for(request_queue.get(), timeout=1.0)
                result = await task
                request_queue.task_done()
                assert result is not None
                if len(result) > 1 and not stop_send.is_set():
                    results.append(result)
                current_count = counter[0] + 1
                counter[0] = current_count
                print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")

                if len(results) >= reqs_num and not stop_send.is_set():
                    end_time[0] = time.time()
                    print("\nReached target number of responses")
                    stop_send.set()
                    if force_terminate and not stop_event.is_set():
                        stop_event.set()
                    else:
                        print("\nWaiting for remaining responses to finish...")

                if current_count >= sent_count[0] and not stop_event.is_set():
                    stop_event.set()

                if stop_event.is_set() and (force_terminate or request_queue.empty()):
                    return

            except asyncio.TimeoutError:
                if stop_event.is_set() and (force_terminate or request_queue.empty()):
                    return
                continue
    except Exception as e:
        print(f"\nError collecting response: {e}")
    finally:
        if force_terminate:
            for task in pending_tasks:
                if not task.done():
                    task.cancel()


async def run_continuous_benchmark(
    async_task, url, prompts, max_new_tokens, reqs_num, num_clients, input_qps, force_terminate
):
    request_queue = asyncio.Queue()
    stop_event = asyncio.Event()
    stop_send = asyncio.Event()
    results_data = []
    counter = [0]
    sent_count = [0]
    end_time = [0.0]
    pending_tasks = []

    async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=10 * reqs_num)) as session:
        sender_task = asyncio.create_task(
            continuous_sender(
                session,
                pending_tasks,
                async_task,
                url,
                prompts,
                max_new_tokens,
                request_queue,
                stop_send,
                sent_count,
                input_qps,
            )
        )

        collector_task = [
            asyncio.create_task(
                response_collector(
                    request_queue,
                    results_data,
                    reqs_num,
                    stop_event,
                    stop_send,
                    counter,
                    end_time,
                    sent_count,
                    force_terminate,
                    pending_tasks,
                )
            )
            for _ in range(num_clients)
        ]
        await asyncio.wait(collector_task)

        if not sender_task.done():
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    return results_data, sent_count[0], end_time[0]


model_name = []


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://127.0.0.1:8000/generate_stream")
    parser.add_argument("--num_clients", type=int, default=100)
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--input_num", type=int, default=2000)
    parser.add_argument("--input_qps", type=float, default=30.0)
    parser.add_argument("--input_len", type=int, default=1024)
    parser.add_argument("--output_len", type=int, default=128)
    parser.add_argument("--server_api", type=str, default="lightllm")
    parser.add_argument("--dump_file", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--force_terminate",
        type=int,
        default=0,
        help="0: waiting all reqs return; 1: only waiting input_num reqs return",
    )

    args = parser.parse_args()
    if args.dump_file and os.path.exists(args.dump_file):
        # Read and print the JSON content.
        with open(args.dump_file, "r") as json_file:
            content = json.load(json_file)
            print(json.dumps(content, indent=4))
        return

    assert args.tokenizer_path is not None
    model_name.append(args.tokenizer_path)
    seed_all(args.seed)
    url = args.url
    tokenizer = get_tokenizer(args.tokenizer_path)
    # In QPS mode the number of requests actually sent is not fixed; tentatively generate 10x reqs_num prompts.
    prompts, input_lens, max_new_tokens = gen_random_data(
        args.input_len, args.output_len, 10 * args.input_num, tokenizer
    )

    percentiles = [25, 50, 75, 90, 95, 99, 100]
    if args.server_api == "lightllm":
        async_post_stream = async_post_stream_lightllm
    else:
        raise Exception(f"Unsupported server_api: {args.server_api}.")

    dump_dict = {}
    dump_dict["backend"] = args.server_api
    dump_dict["clients"] = args.num_clients

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    start_time = time.time()
    results, sent_reqs, end_time = loop.run_until_complete(
        run_continuous_benchmark(
            async_post_stream,
            url,
            prompts,
            max_new_tokens,
            args.input_num,
            args.num_clients,
            args.input_qps,
            args.force_terminate,
        )
    )
    loop.close()

    first_token_time = []
    decode_token_time = []
    request_time = []
    final_output_lens = []
    valid_num = 0
    for result in results:
        if len(result) > 1:  # only keep requests that decoded at least two tokens
            first_token_time.append(result[0])
            decode_token_time.append(sum(result[1:]) / len(result[1:]))
            request_time.append(sum(result))
            final_output_lens.append(len(result))
            valid_num += 1

    print(
        f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"
    )
    print(f"Total QPS: {valid_num / (end_time - start_time)}")
    print(f"Sender QPS: {sent_reqs / (end_time - start_time)}")
    print(f"Avg Input Length: {sum(input_lens) / len(input_lens)}")
    print(f"Avg Output Length: {sum(final_output_lens) / len(final_output_lens)}")
    print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)} token/s")
    print(f"Input Throughput: {sum(input_lens) / (end_time - start_time)} token/s")
    print(f"Output Throughput: {sum(final_output_lens) / (end_time - start_time)} token/s")
    print("-" * 10)
    dump_dict["request_num"] = valid_num
    dump_dict["Total QPS"] = valid_num / (end_time - start_time)
    dump_dict["Sender QPS"] = sent_reqs / (end_time - start_time)
    dump_dict["Avg Input Length"] = sum(input_lens) / len(input_lens)
    dump_dict["Avg Output Length"] = sum(final_output_lens) / len(final_output_lens)
    dump_dict["Total Throughput"] = (sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)
    dump_dict["Input Throughput"] = sum(input_lens) / (end_time - start_time)
    dump_dict["Output Throughput"] = sum(final_output_lens) / (end_time - start_time)

    values = np.percentile(request_time, percentiles)
    request_time_dict = {}
    for percentile, value in zip(percentiles, values):
        print(f"request_time P{percentile}: {value:.6f}s")
        request_time_dict[f"P{percentile}"] = value
    dump_dict["request_time"] = request_time_dict
    print("-" * 10)

    first_token_time_dict = {}
    values = np.percentile(first_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"first_token_time P{percentile}: {value:.6f}s")
        first_token_time_dict[f"P{percentile}"] = value
    dump_dict["first_token_time_dict"] = first_token_time_dict
    print("-" * 10)

    decode_token_time_dict = {}
    values = np.percentile(decode_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"decode_token_time P{percentile}: {value * 1000:.6f}ms")
        decode_token_time_dict[f"P{percentile}"] = value * 1000
    dump_dict["decode_token_time_dict"] = decode_token_time_dict
    print(dump_dict)

    if args.dump_file:
        with open(args.dump_file, "w") as json_file:
            json.dump(dump_dict, json_file, indent=4)
            print(f"Results have been written to {args.dump_file}")


if __name__ == "__main__":
    main()
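
For reference, a minimal sketch (not part of the commit) of how main() turns each request's list of inter-chunk gaps into the reported metrics; the numbers below are made up for illustration:

    # Each result is the list of gaps (in seconds) between streamed chunks of one request.
    gaps = [0.350, 0.021, 0.019, 0.020]               # illustrative values only
    ttft = gaps[0]                                     # first_token_time: time to the first chunk
    decode_latency = sum(gaps[1:]) / len(gaps[1:])     # decode_token_time: mean gap after the first chunk
    request_latency = sum(gaps)                        # request_time: end-to-end latency

A typical run points the script at an already running lightllm server, e.g. python test/benchmark_qps.py --tokenizer_path <model_dir> --input_qps 30 --input_num 2000, where <model_dir> is a placeholder for the tokenizer path.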
