
import argparse
import json
+ import os
+ import random
import subprocess
from time import sleep, time
- from typing import Optional
+ from typing import Optional, Union

import datasets
import logging
logger = logging.getLogger("server-bench")


- def get_prompts(n_prompts: int) -> list[str]:
-     logger.info("Loading MMLU dataset...")
-     ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+     ret = []
+     if dataset_name.lower() == "mmlu":
+         logger.info("Loading MMLU dataset...")
+         ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+     else:
+         return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


- def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+ def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+     assert n_prompts >= 0
+     ret: list[int] = []
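+     # Re-seed per prompt index so each length is reproducible and independent of how many prompts are requested.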
+     for i in range(n_prompts):
+         random.seed(13 * i + 0)
+         ret.append(random.randint(prompt_length_min, prompt_length_max))
+     return ret
+
+
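+ # Synthetic prompts are lists of random token IDs (the llama.cpp /completion endpoint also accepts token arrays as prompts).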
+ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+     return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
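+ # The server is configured via LLAMA_ARG_* environment variables (see llama-server --help); only the binary path and log destination are passed here.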
+ def get_server(path_server: str, path_log: Optional[str]) -> dict:
    logger.info("Starting the llama.cpp server...")
-     address = f"http://localhost:{port}"
-
-     popen_args: list[str] = [
-         path_server,
-         "--flash-attn",
-         "--n-gpu-layers", str(n_gpu_layers),
-         "--parallel", str(parallel),
-         "--ctx-size", str(parallel * ctx_size),
-         "--model", path_model,
-         "--port", str(port),
-         "--swa-full",  # FIXME performance bad otherwise
-         # "--attn-streams",
-     ]
-     fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-     process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+     hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+     port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+     address: str = f"http://{hostname}:{port}"
+
+     fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+     process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
-                 raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                 raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
-                 raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                 raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}

@@ -87,76 +97,116 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

-     response = session.post(
-         f"{server_address}/apply-template",
-         json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-     )
-     if response.status_code != 200:
-         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-     prompt: str = json.loads(response.text)["prompt"]
-
-     json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-     response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+     t_submit = time()
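+     # Synthetic prompts (token IDs) go straight to /completion; text prompts are first rendered with the model's chat template via /apply-template.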
+     if data["synthetic_prompt"]:
+         json_data: dict = {
+             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+     else:
+         response = session.post(
+             f"{server_address}/apply-template",
+             json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+         )
+         if response.status_code != 200:
+             raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+         prompt: str = json.loads(response.text)["prompt"]
+
+         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+         response = session.post(f"{server_address}/completion", json=json_data, stream=True)

-     last_valid_line: str = ""
    token_arrival_times: list[float] = []
-     for line in response.iter_lines(decode_unicode=True):
-         if not line.startswith("data: "):
+     for line in response.iter_lines(decode_unicode=False):
+         if not line.startswith(b"data: "):
            continue
-         last_valid_line = line
        token_arrival_times.append(time())
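+     # Drop the timestamp of the final "data: " line, which carries the end-of-stream summary rather than a newly generated token.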
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-     timings: dict = json.loads(last_valid_line[6:])["timings"]

-     return (timings["prompt_ms"], token_arrival_times)
-
-
- def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-     num_workers: int = parallel + 1
-     prompts: list[str] = get_prompts(n_prompts)
+     return (t_submit, token_arrival_times)
+
+
+ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
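+     # Provide sensible server defaults via LLAMA_ARG_* environment variables unless the user has already set them.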
+     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+     if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+     if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+     parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+     prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+     synthetic_prompts: bool = prompts is None
+     prompt_n = []
+
+     if synthetic_prompts:
+         prompt_source_split: list[str] = prompt_source.split("-")
+         assert len(prompt_source_split) == 3
+         assert prompt_source_split[0].lower() == "rng"
+         prompt_length_min: int = int(prompt_source_split[1])
+         prompt_length_max: int = int(prompt_source_split[2])
+         logger.info("Generating random prompts...")
+         prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+         prompts = get_prompts_rng(prompt_n)
+     else:
+         n_predict_min = n_predict
+
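+     # Unless set explicitly, give each slot enough context for the longest prompt (2048 tokens assumed for text prompts) plus n_predict, with 5% headroom.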
+     if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+         context_total: int = context_per_slot * parallel
+         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+         logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
-         server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+         server = get_server(path_server, path_log)
        server_address: str = server["address"]

-         adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
-         for i, p in enumerate(prompts):
-             data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

-         logger.info("Getting the prompt lengths...")
-         prompt_n = [get_prompt_length(d) for d in data]
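+         # Each request gets a deterministic seed; n_predict is drawn from [n_predict_min, n_predict] (for text prompts n_predict_min was set equal to n_predict above).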
+         for i, p in enumerate(prompts):  # type: ignore
+             random.seed(13 * i + 1)
+             data.append({
+                 "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                 "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+
+         if not synthetic_prompts:
+             logger.info("Getting the prompt lengths...")
+             prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
-         results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

-     prompt_ms = []
+     prompt_t = []
    token_t = []
    depth_sum: int = 0
-     for pn, (pms, tat) in zip(prompt_n, results):
-         prompt_ms.append(pms)
+     for pn, (t_submit, tat) in zip(prompt_n, results):
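+         # Time to first token is measured client-side, from request submission to the first streamed token.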
+         prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
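+         # Accumulate the context depth at which each generated token was produced: prompt length plus its offset within the generation.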
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
+     assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
-     prompt_ms = np.array(prompt_ms, dtype=np.float64)
+     prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
    logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
-     logger.info(f"Average prompt latency:            {np.mean(prompt_ms):.2f} ms")
-     logger.info(f"Average prompt speed:              {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+     logger.info(f"Average prompt latency:            {1e3 * np.mean(prompt_t):.2f} ms")
+     logger.info(f"Average prompt speed:              {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+     logger.info("")
+     logger.info(
+         "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+         "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")

    plt.figure()
-     plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-     plt.xlim(0, 1.05 * np.max(prompt_n))
-     plt.ylim(0, 1.05 * np.max(prompt_ms))
-     plt.title(path_model)
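+     # prompt_t is in seconds; plot the time to first token in milliseconds.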
+     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+     plt.xlim(0, 1.05e0 * np.max(prompt_n))
+     plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
-     plt.title(path_model)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-         "Results are printed to console and visualized as plots (saved to current working directory).")
+         "Results are printed to console and visualized as plots (saved to current working directory). "
+         "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-     parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-     parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-     parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-     parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-     parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-     parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-     parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+     parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the log file for the server output")
+     parser.add_argument(
+         "--prompt_source", type=str, default="rng-1024-2048",
+         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+         "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+     parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+     parser.add_argument(
+         "--n_predict_min", type=int, default=1024,
+         help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    args = parser.parse_args()
    benchmark(**vars(args))