
Commit 41b7ed7

Author: niushengxiao
Commit message: feat: add benchmark_qps in test
1 parent 1e01498 · commit 41b7ed7

5 files changed, +374 -5 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,11 @@ def __init__(self, kvargs):
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
+        self.graph_max_batch_size = (
+            self.graph_max_batch_size // 2
+            if get_env_start_args().enable_decode_microbatch_overlap
+            else self.graph_max_batch_size
+        )
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
         self.quant_type = kvargs.get("quant_type", "none")
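When get_env_start_args().enable_decode_microbatch_overlap is set, decode is captured as two overlapped micro-batches (see the warmup_overlap change below), so each micro-batch only gets half of the configured graph_max_batch_size. A minimal standalone sketch of the expression added above (illustration only, not lightllm API):

    def effective_graph_batch_size(configured: int, overlap_enabled: bool) -> int:
        # Each of the two overlapped decode micro-batches gets half of the configured budget.
        return configured // 2 if overlap_enabled else configured

    assert effective_graph_batch_size(16, False) == 16  # default graph_max_batch_size
    assert effective_graph_batch_size(16, True) == 8    # halved when microbatch overlap is enabled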

lightllm/common/basemodel/cuda_graph.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def warmup(self, model):
     @torch.no_grad()
     def warmup_overlap(self, model):
         logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.")
-        for batch_size in range(self.max_batch_size, 0, -1):
+        for batch_size in range(self.max_batch_size // 2, 0, -1):
             decode_batches = []
             for micro_batch_index in [0, 1]:
                 # dummy prefill

lightllm/utils/envs_utils.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def enable_env_vars(args):
 
 @lru_cache(maxsize=None)
 def get_deepep_num_max_dispatch_tokens_per_rank():
-    return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256))
+    return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128))
 
 
 def get_lightllm_gunicorn_time_out_seconds():
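The default dispatch-token budget per rank drops from 256 to 128, and the NUM_MAX_DISPATCH_TOKENS_PER_RANK environment variable still overrides it. A minimal sketch of restoring the previous default (assumes the variable is set before the getter is first called, since the result is cached by lru_cache):

    import os

    # Must be set before get_deepep_num_max_dispatch_tokens_per_rank() is first called;
    # after that the lru_cache'd value (now defaulting to 128) is fixed for the process.
    os.environ["NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = "256"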

test/benchmark_qps.py

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
import os
import argparse
import yaml
import requests
import json
import time
import random
import numpy as np
from tqdm import tqdm
from typing import Union, List, Tuple
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
import aiohttp
import asyncio


def seed_all(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)


def get_tokenizer(
    tokenizer_name: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    return tokenizer


def get_output_length(reqs_num: int, output_len: int) -> List[int]:
    min_len, max_len = 2, output_len * 2
    mean = (min_len + max_len) * 0.5
    std = mean
    output_lens = []
    for _ in range(reqs_num):
        cur_len = random.gauss(mean, std)
        cur_len = round(cur_len)
        if cur_len < min_len:
            cur_len = min_len
        elif cur_len > max_len:
            cur_len = max_len
        output_lens.append(cur_len)
    return output_lens


def gen_random_input_text(tokenizer) -> str:
    random_ids = [random.randint(512, 8192) for _ in range(1024)]
    random_text = tokenizer.decode(random_ids)
    return random_text


def gen_random_data(
    input_len: int, output_len: int, reqs_num: int, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Tuple[List[str], List[int], List[int]]:
    prompts = []
    input_lens = []
    output_lens = get_output_length(reqs_num, output_len)
    for i in range(reqs_num):
        input_text = gen_random_input_text(tokenizer)
        prompts.append(input_text)
        input_lens.append(input_len)
    print("Generate random data finish.")
    return prompts, input_lens, output_lens


async def async_post_stream_lightllm(url, text_input, max_new_tokens, session) -> List[float]:
    try:
        data = {
            "inputs": text_input,
            "parameters": {
                "do_sample": False,
                "ignore_eos": True,
                "max_new_tokens": max_new_tokens,
            },
        }
        headers = {"Content-Type": "application/json"}
        used_time = []
        start_time = time.time()
        last_time = start_time

        async with session.post(url, headers=headers, json=data) as response:
            if response.status != 200:
                return []

            async for line in response.content:
                if line and line.startswith(b"data:"):
                    # print(line)
                    current_time = time.time()
                    elapsed_time = current_time - last_time
                    used_time.append(elapsed_time)
                    last_time = current_time
        return used_time
    except Exception:
        pass


async def continuous_sender(
    session, pending_tasks, async_task, url, prompts, max_new_tokens, request_queue, stop_send, sent_count, input_qps
):
    prompt_index = 0
    while not stop_send.is_set():
        prompt = prompts[prompt_index % len(prompts)]
        max_tokens = max_new_tokens[prompt_index % len(max_new_tokens)]

        task = asyncio.create_task(async_task(url, prompt, max_tokens, session))
        pending_tasks.append(task)
        await request_queue.put(task)

        prompt_index += 1
        sent_count[0] += 1
        # control the send rate (throttle to the target QPS)
        await asyncio.sleep(1.0 / input_qps)


async def response_collector(
    request_queue,
    results,
    reqs_num,
    stop_event,
    stop_send,
    counter,
    end_time,
    sent_count,
    force_terminate,
    pending_tasks,
):
    try:
        while True:
            try:
                task = await asyncio.wait_for(request_queue.get(), timeout=1.0)
                result = await task
                request_queue.task_done()
                if len(result) > 1 and not stop_send.is_set():
                    results.append(result)
                current_count = counter[0] + 1
                counter[0] = current_count
                print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")

                if len(results) >= reqs_num and not stop_send.is_set():
                    end_time[0] = time.time()
                    print("\nReached target number of responses")
                    stop_send.set()
                    if force_terminate and not stop_event.is_set():
                        stop_event.set()
                    else:
                        print("\nWaiting for remaining responses to finish...")

                if current_count >= sent_count[0] and not stop_event.is_set():
                    stop_event.set()

                if stop_event.is_set() and (force_terminate or request_queue.empty()):
                    return

            except asyncio.TimeoutError:
                if stop_event.is_set() and (force_terminate or request_queue.empty()):
                    return
                continue
            except Exception as e:
                print(f"\nError collecting response: {e}")
    finally:
        if force_terminate:
            for task in pending_tasks:
                if not task.done():
                    task.cancel()


async def run_continuous_benchmark(
    async_task, url, prompts, max_new_tokens, reqs_num, num_clients, input_qps, force_terminate
):
    request_queue = asyncio.Queue()
    stop_event = asyncio.Event()
    stop_send = asyncio.Event()
    results_data = []
    counter = [0]
    sent_count = [0]
    end_time = [0.0]
    pending_tasks = []

    async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=10 * reqs_num)) as session:
        sender_task = asyncio.create_task(
            continuous_sender(
                session,
                pending_tasks,
                async_task,
                url,
                prompts,
                max_new_tokens,
                request_queue,
                stop_send,
                sent_count,
                input_qps,
            )
        )

        collector_task = [
            asyncio.create_task(
                response_collector(
                    request_queue,
                    results_data,
                    reqs_num,
                    stop_event,
                    stop_send,
                    counter,
                    end_time,
                    sent_count,
                    force_terminate,
                    pending_tasks,
                )
            )
            for _ in range(num_clients)
        ]
        await asyncio.wait(collector_task)

        if not sender_task.done():
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    return results_data, sent_count[0], end_time[0]


model_name = []


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://127.0.0.1:8000/generate_stream")
    parser.add_argument("--num_clients", type=int, default=100)
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--input_num", type=int, default=2000)
    parser.add_argument("--input_qps", type=float, default=30.0)
    parser.add_argument("--input_len", type=int, default=1024)
    parser.add_argument("--output_len", type=int, default=128)
    parser.add_argument("--server_api", type=str, default="lightllm")
    parser.add_argument("--dump_file", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--force_terminate",
        type=int,
        default=0,
        help="0: waiting all reqs return; 1: only waiting input_num reqs return",
    )

    args = parser.parse_args()
    if args.dump_file and os.path.exists(args.dump_file):
        # read and print the JSON content
        with open(args.dump_file, "r") as json_file:
            content = json.load(json_file)
            print(json.dumps(content, indent=4))
        return

    assert args.tokenizer_path is not None
    model_name.append(args.tokenizer_path)
    seed_all(args.seed)
    url = args.url
    tokenizer = get_tokenizer(args.tokenizer_path)
    # In QPS sending mode the number of requests actually sent is not fixed; tentatively generate 10x reqs_num prompts.
    prompts, input_lens, max_new_tokens = gen_random_data(
        args.input_len, args.output_len, 10 * args.input_num, tokenizer
    )

    percentiles = [25, 50, 75, 90, 95, 99, 100]
    if args.server_api == "lightllm":
        async_post_stream = async_post_stream_lightllm
    else:
        raise Exception(f"Not support {args.server_api} server_api.")

    dump_dict = {}
    dump_dict["backend"] = args.server_api
    dump_dict["clients"] = args.num_clients

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    start_time = time.time()
    results, sent_reqs, end_time = loop.run_until_complete(
        run_continuous_benchmark(
            async_post_stream,
            url,
            prompts,
            max_new_tokens,
            args.input_num,
            args.num_clients,
            args.input_qps,
            args.force_terminate,
        )
    )
    loop.close()

    first_token_time = []
    decode_token_time = []
    request_time = []
    final_output_lens = []
    valid_num = 0
    for result in results:
        if len(result) > 1:  # only count requests that decoded at least two tokens
            first_token_time.append(result[0])
            decode_token_time.append(sum(result[1:]) / len(result[1:]))
            request_time.append(sum(result))
            final_output_lens.append(len(result))
            valid_num += 1

    print(
        f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"
    )
    print(f"Total QPS: {valid_num / (end_time - start_time)}")
    print(f"Sender QPS: {sent_reqs / (end_time - start_time)}")
    print(f"Avg Input Length: {sum(input_lens) / len(input_lens)}")
    print(f"Avg Output Length: {sum(final_output_lens) / len(final_output_lens)}")
    print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)} token/s")
    print(f"Input Throughput: {sum(input_lens) / (end_time - start_time)} token/s")
    print(f"Output Throughput: {sum(final_output_lens) / (end_time - start_time)} token/s")
    print("-" * 10)
    dump_dict["request_num"] = valid_num
    dump_dict["Total QPS"] = valid_num / (end_time - start_time)
    dump_dict["Sender QPS"] = sent_reqs / (end_time - start_time)
    dump_dict["Avg Input Length"] = sum(input_lens) / len(input_lens)
    dump_dict["Avg Output Length"] = sum(final_output_lens) / len(final_output_lens)
    dump_dict["Total Throughput"] = (sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)
    dump_dict["Input Throughput"] = sum(input_lens) / (end_time - start_time)
    dump_dict["Output Throughput"] = sum(final_output_lens) / (end_time - start_time)

    values = np.percentile(request_time, percentiles)
    request_time_dict = {}
    for percentile, value in zip(percentiles, values):
        print(f"request_time P{percentile}: {value:.6f}s")
        request_time_dict[f"P{percentile}"] = value
    dump_dict["request_time"] = request_time_dict
    print("-" * 10)

    first_token_time_dict = {}
    values = np.percentile(first_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"first_token_time P{percentile}: {value:.6f}s")
        first_token_time_dict[f"P{percentile}"] = value
    dump_dict["first_token_time_dict"] = first_token_time_dict
    print("-" * 10)

    decode_token_time_dict = {}
    values = np.percentile(decode_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"decode_token_time P{percentile}: {value * 1000:.6f}ms")
        decode_token_time_dict[f"P{percentile}"] = value * 1000
    dump_dict["decode_token_time_dict"] = decode_token_time_dict
    print(dump_dict)

    if args.dump_file:
        with open(args.dump_file, "w") as json_file:
            json.dump(dump_dict, json_file, indent=4)
            print(f"Results have been written to {args.dump_file}")


if __name__ == "__main__":
    main()
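A typical invocation of the new benchmark, using only the flags defined in main() above (the tokenizer path and dump file are placeholders; the URL shown is the script's default):

    python test/benchmark_qps.py \
        --url http://127.0.0.1:8000/generate_stream \
        --tokenizer_path /path/to/model \
        --input_num 2000 \
        --input_qps 30 \
        --input_len 1024 \
        --output_len 128 \
        --num_clients 100 \
        --dump_file qps_result.json

Because of the early check in main(), re-running with just --dump_file qps_result.json pointing at an existing file prints the previously saved results instead of starting a new run.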
