@@ -4,6 +4,7 @@
 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
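+    # A URL instead of a binary path means the benchmark targets an already running server; no local process is launched.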
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
         f"{server_address}/apply-template",
         json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     prompt: str = json.loads(response.text)["prompt"]
     response = session.post(
         f"{server_address}/tokenize",
         json={"content": prompt, "add_special": True}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     tokens: list[str] = json.loads(response.text)["tokens"]
     return len(tokens)
 
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
 
     t_submit = time()
-    if data["synthetic_prompt"]:
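+    # An external server is queried via the OpenAI-compatible /v1/completions endpoint;
+    # a locally launched llama-server keeps using its native endpoints below.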
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
             f"{server_address}/apply-template",
             json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
         )
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        response.raise_for_status()
         prompt: str = json.loads(response.text)["prompt"]
 
         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
 
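+    # Collect the raw "data: " lines so that trailing non-token chunks can be identified once the stream ends.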
+    lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
+        lines.append(line)
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
-
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
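+    # If the stream ended with an extra timings chunk, the corresponding arrival time is not a generated token either.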
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
 
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
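+    # A URL passed as path_server means the server is managed externally; the llama-server specific
+    # defaults below (GPU layers, FlashAttention, context size) are then left untouched.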
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
 
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
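+        # get_server() returns no process handle exactly when an external server address was given.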
+        assert external_server == (server["process"] is None)
 
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
             if seed_offset >= 0:
                 random.seed(3 * (seed_offset + 1000 * i) + 1)
             data.append({
-                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-                "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
-    logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+
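+    # Optionally persist the benchmark configuration and the measured runtime in a SQLite database.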
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
 
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "