 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
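+    # An http(s) URL means an already-running external server: don't spawn a local process.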
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
8992 f"{ server_address } /apply-template" ,
9093 json = {"messages" : [{"role" : "user" , "content" : data ["prompt" ], "stream" : True }]}
9194 )
92- if response .status_code != 200 :
93- raise RuntimeError (f"Server returned status code { response .status_code } : { response .text } " )
95+ response .raise_for_status ()
9496 prompt : str = json .loads (response .text )["prompt" ]
9597 response = session .post (
9698 f"{ server_address } /tokenize" ,
9799 json = {"content" : prompt , "add_special" : True }
98100 )
99- if response .status_code != 200 :
100- raise RuntimeError (f"Server returned status code { response .status_code } : { response .text } " )
101+ response .raise_for_status ()
101102 tokens : list [str ] = json .loads (response .text )["tokens" ]
102103 return len (tokens )
103104
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
 
     t_submit = time()
-    if data["synthetic_prompt"]:
+    if data["external_server"]:
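+        # External servers are queried via the OpenAI-compatible /v1/completions
+        # endpoint, which takes "max_tokens" rather than llama.cpp's native "n_predict".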
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
117123 f"{ server_address } /apply-template" ,
118124 json = {"messages" : [{"role" : "user" , "content" : data ["prompt" ], "stream" : True }]}
119125 )
120- if response .status_code != 200 :
121- raise RuntimeError (f"Server returned status code { response .status_code } : { response .text } " )
126+ response .raise_for_status ()
122127 prompt : str = json .loads (response .text )["prompt" ]
123128
124129 json_data : dict = {"prompt" : prompt , "seed" : data ["seed" ], "n_predict" : data ["n_predict" ], "stream" : True }
125130 response = session .post (f"{ server_address } /completion" , json = json_data , stream = True )
131+ response .raise_for_status ()
126132
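+    # Collect the raw "data: " stream lines so trailing metadata can be detected below.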
+    lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
+        lines.append(line)
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
-
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
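+    # The final stream line is not a generated token (its timestamp was dropped above);
+    # a second-to-last line containing "timings" is server metadata, so drop it as well.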
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
 
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
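+    # A URL instead of a binary path selects an already-running external server.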
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
 
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
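+        # Sanity check: a server process only exists when the server is managed locally.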
+        assert external_server == (server["process"] is None)
 
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         if seed_offset >= 0:
             random.seed(3 * (seed_offset + 1000 * i) + 1)
         data.append({
-            "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-            "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+            "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+            "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+            "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
-    logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+
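+    # Optionally append a summary row for this run to an SQLite database.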
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
 
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast relative to the network or the Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary, or URL of an already running server")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the file to write the server log to")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an SQLite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "