 
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
-from typing import Optional
+from typing import Optional, Union
 
 import datasets
 import logging
 logger = logging.getLogger("server-bench")
 
 
-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
        ret = ret[:n_prompts]
     return ret
 
 
-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
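+# The fixed per-index seed makes the sampled prompt lengths reproducible across runs.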
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
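+# Synthetic prompts are lists of raw token IDs drawn uniformly from [100, 10000]; the model's actual vocabulary is not consulted.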
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
     logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
     while True:
         try:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
         except requests.ConnectionError:
             n_failures += 1
             if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
 
     return {"process": process, "address": address, "fout": fout}
 
@@ -87,76 +97,116 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]
 
-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
 
-    last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
             continue
-        last_valid_line = line
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
 
     if response.status_code != 200:
         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    timings: dict = json.loads(last_valid_line[6:])["timings"]
 
-    return (timings["prompt_ms"], token_arrival_times)
-
-
-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
 
     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
         server_address: str = server["address"]
 
-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
         session.mount("https://", adapter)
 
         data: list[dict] = []
-        for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
 
-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        for i, p in enumerate(prompts):  # type: ignore
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
 
         logger.info("Starting the benchmark...\n")
         t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
         if server is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
             session.close()
 
-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
         token_t += tat
         n_tokens: int = len(tat)
         depth_sum += n_tokens * pn
         depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0
     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)
 
     token_t -= t0
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
     logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens: {token_t.shape[0]}")
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
 
     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the log file for the server output")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
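
Usage sketch (not part of the patch): with this change the server is configured through llama-server environment variables rather than CLI flags of the script. The snippet below is a minimal illustration of calling benchmark() programmatically; it assumes the script is importable as server_bench, that the model path is picked up via LLAMA_ARG_MODEL (check llama-server --help for the exact variable names), and the model path itself is hypothetical.

import os
import server_bench  # hypothetical module name for the script above

# Environment variables are inherited by the llama-server subprocess started by get_server().
os.environ["LLAMA_ARG_MODEL"] = "models/model.gguf"  # hypothetical model path
os.environ["LLAMA_ARG_N_PARALLEL"] = "16"            # 16 server slots instead of the script's default of 32

server_bench.benchmark(
    path_server="llama-server",
    path_log="server-bench.log",
    prompt_source="rng-1024-2048",  # synthetic prompts of 1024-2048 tokens
    n_prompts=100,
    n_predict=2048,
    n_predict_min=1024,
)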