
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
 from typing import Optional
 logger = logging.getLogger("server-bench")


-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret


-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
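Reader note (not part of the diff): `get_prompt_lengths_rng` reseeds the RNG from the prompt index, so the synthetic workload is reproducible across runs, and `get_prompts_rng` emits lists of raw token ids rather than text. A minimal sketch of that behaviour, assuming the two helpers above are importable:

```python
# Lengths depend only on the index-derived seed, so two calls agree exactly.
lengths_a = get_prompt_lengths_rng(4, 1024, 2048)
lengths_b = get_prompt_lengths_rng(4, 1024, 2048)
assert lengths_a == lengths_b

# Synthetic prompts are lists of integer token ids, one list per requested length.
prompts = get_prompts_rng(lengths_a)
assert all(len(p) == n for p, n in zip(prompts, lengths_a))
```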
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
     logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

     n_failures: int = 0
     while True:
         try:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
         except requests.ConnectionError:
             n_failures += 1
             if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

     return {"process": process, "address": address, "fout": fout}

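Reader note (not part of the diff): `get_server` now launches the binary without any command-line flags, so everything, including the model path, must reach the server through `LLAMA_ARG_*` environment variables (see `llama-server --help`). A hedged sketch of how a caller might set this up before running the benchmark; the model path is a placeholder:

```python
import os

# Assumed LLAMA_ARG_* variables mirroring llama-server's CLI options.
os.environ["LLAMA_ARG_MODEL"] = "/path/to/model.gguf"  # placeholder path
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"             # get_server() falls back to this value
os.environ["LLAMA_ARG_PORT"] = "8080"                  # get_server() falls back to this value
# Anything left unset (parallel slots, GPU layers, flash attention, context size)
# is filled in with defaults by benchmark() before the server is started.
```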
@@ -87,76 +97,118 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]

-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)

     last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
             continue
         last_valid_line = line
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]

     if response.status_code != 200:
         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    timings: dict = json.loads(last_valid_line[6:])["timings"]

-    return (timings["prompt_ms"], token_arrival_times)
-
-
-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Optional[list[str]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
         server_address: str = server["address"]

-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
         session.mount("https://", adapter)

         data: list[dict] = []
+
         for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})

-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]

         logger.info("Starting the benchmark...\n")
         t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
         if server is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
             session.close()

-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
         token_t += tat
         n_tokens: int = len(tat)
         depth_sum += n_tokens * pn
         depth_sum += n_tokens * (n_tokens + 1) // 2
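Reader note (not part of the diff): for request $i$ with prompt length $p_i$ and $n_i$ generated tokens, the two accumulations above add $n_i p_i + n_i(n_i+1)/2$, i.e. the summed context depth at which each generated token was produced; dividing `depth_sum` by the total number of generated tokens yields the "Average generation depth" logged further down.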
+    assert len(token_t) > 0
     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)

     token_t -= t0
@@ -167,18 +219,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
     logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens: {token_t.shape[0]}")
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")

     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +242,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +250,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the file the server output is logged to")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
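Reader note (not part of the diff): an illustrative programmatic equivalent of the CLI defaults above, assuming the model path has already been exported via `LLAMA_ARG_MODEL` as in the earlier sketch:

```python
benchmark(
    path_server="llama-server",     # --path_server default
    path_log="server-bench.log",    # --path_log default
    prompt_source="rng-1024-2048",  # synthetic prompts of 1024-2048 tokens
    n_prompts=100,                  # --n_prompts default
    n_predict=2048,                 # --n_predict default
    n_predict_min=1024,             # --n_predict_min default (synthetic prompts only)
)
```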