 from typing import Optional

 import datasets
+import logging
 import matplotlib.pyplot as plt
 import numpy as np
 import requests
 from tqdm.contrib.concurrent import thread_map


+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("server-bench")
+
+
 def get_prompts(n_prompts: int) -> list[str]:
-    print("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
+    logger.info("Loading MMLU dataset...")
+    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret


 def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
-    print("Starting the llama.cpp server...")
+    logger.info("Starting the llama.cpp server...")
     address = f"http://localhost:{port}"

     popen_args: list[str] = [
@@ -78,7 +83,7 @@ def get_prompt_length(data: dict) -> int:
     return len(tokens)


-def send_prompt(data: dict) -> tuple[int, float, list[float]]:
+def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]

@@ -93,6 +98,7 @@ def send_prompt(data: dict) -> tuple[int, float, list[float]]:
     json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
     response = session.post(f"{server_address}/completion", json=json_data, stream=True)

+    last_valid_line: str = ""
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=True):
         if not line.startswith("data: "):
@@ -111,21 +117,20 @@ def send_prompt(data: dict) -> tuple[int, float, list[float]]:
 def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
     prompts: list[str] = get_prompts(n_prompts)

-    server = None
+    server: Optional[dict] = None
     try:
-        server: dict = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
         server_address: str = server["address"]

         with requests.Session() as session:
             data: list[dict] = []
             for i, p in enumerate(prompts):
                 data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

-            print("Getting the prompt lengths...")
-            prompt_n: list[int] = [get_prompt_length(d) for d in data]
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]

-            print("Starting the benchmark...")
-            print()
+            logger.info("Starting the benchmark...\n")
             t0 = time()
             results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=parallel + 1, chunksize=1)
     finally:
@@ -149,17 +154,17 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     token_t -= t0
     token_t_last = np.max(token_t)

-    print()
-    print(f"Benchmark duration: {token_t_last:.2f} s")
-    print(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
-    print(f"Total prompt length: {np.sum(prompt_n)} tokens")
-    print(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    print(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    print(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
-    print(f"Total generated tokens: {token_t.shape[0]}")
-    print(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
-    print(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
-    print(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
+    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
+    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
+    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
+    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Total generated tokens: {token_t.shape[0]}")
+    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
+    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
+    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

     plt.figure()
     plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
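
For reference, a minimal standalone sketch of the logging pattern this diff adopts (a module-level `logging.basicConfig(level=logging.INFO)` plus a named `"server-bench"` logger replacing bare `print` calls). The `report` helper and its arguments are hypothetical and only illustrate the reporting style; they are not part of the actual script:

```python
import logging

# Configure the root handler once at module import, as the diff does.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server-bench")


def report(duration_s: float, n_requests: int) -> None:
    # Hypothetical helper mirroring how the benchmark summary is emitted via logger.info().
    logger.info(f"Benchmark duration: {duration_s:.2f} s")
    logger.info(f"Request throughput: {n_requests / duration_s:.2f} requests/s")


if __name__ == "__main__":
    report(12.5, 100)
```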