Skip to content

Commit 4850b52

Browse files
server-bench: external OAI servers, sqlite (#15179)
* server-bench: external OAI servers, sqlite * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * raise_for_status --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent cd6983d commit 4850b52

File tree

1 file changed

+48
-22
lines changed

1 file changed

+48
-22
lines changed

scripts/server-bench.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import os
66
import random
7+
import sqlite3
78
import subprocess
89
from time import sleep, time
910
from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
4748

4849

4950
def get_server(path_server: str, path_log: Optional[str]) -> dict:
51+
if path_server.startswith("http://") or path_server.startswith("https://"):
52+
return {"process": None, "address": path_server, "fout": None}
5053
if os.environ.get("LLAMA_ARG_HOST") is None:
5154
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
5255
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
8992
f"{server_address}/apply-template",
9093
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
9194
)
92-
if response.status_code != 200:
93-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
95+
response.raise_for_status()
9496
prompt: str = json.loads(response.text)["prompt"]
9597
response = session.post(
9698
f"{server_address}/tokenize",
9799
json={"content": prompt, "add_special": True}
98100
)
99-
if response.status_code != 200:
100-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
101+
response.raise_for_status()
101102
tokens: list[str] = json.loads(response.text)["tokens"]
102103
return len(tokens)
103104

@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
107108
server_address: str = data["server_address"]
108109

109110
t_submit = time()
110-
if data["synthetic_prompt"]:
111+
if data["external_server"]:
112+
json_data: dict = {
113+
"prompt": data["prompt"], "ignore_eos": True,
114+
"seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
115+
response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
116+
elif data["synthetic_prompt"]:
111117
json_data: dict = {
112118
"prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
113119
"seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
117123
f"{server_address}/apply-template",
118124
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
119125
)
120-
if response.status_code != 200:
121-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
126+
response.raise_for_status()
122127
prompt: str = json.loads(response.text)["prompt"]
123128

124129
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
125130
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
131+
response.raise_for_status()
126132

133+
lines = []
127134
token_arrival_times: list[float] = []
128135
for line in response.iter_lines(decode_unicode=False):
129136
if not line.startswith(b"data: "):
130137
continue
138+
lines.append(line)
131139
token_arrival_times.append(time())
132140
token_arrival_times = token_arrival_times[:-1]
133-
134-
if response.status_code != 200:
135-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
141+
if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
142+
token_arrival_times = token_arrival_times[:-1]
136143

137144
return (t_submit, token_arrival_times)
138145

139146

140-
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
147+
def benchmark(
148+
path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
149+
n_predict: int, n_predict_min: int, seed_offset: int):
150+
external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
141151
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
142152
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
143153
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
144-
if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
154+
if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
145155
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
146156
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
147-
if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
157+
if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
148158
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
149159
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
150160

@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
165175
else:
166176
n_predict_min = n_predict
167177

168-
if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
178+
if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
169179
context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
170180
context_total: int = context_per_slot * parallel
171181
os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
176186
try:
177187
server = get_server(path_server, path_log)
178188
server_address: str = server["address"]
189+
assert external_server == (server["process"] is None)
179190

180191
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
181192
session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
188199
if seed_offset >= 0:
189200
random.seed(3 * (seed_offset + 1000 * i) + 1)
190201
data.append({
191-
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
192-
"n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
202+
"session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
203+
"synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
204+
"seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
193205

194206
if not synthetic_prompts:
195207
logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
199211
t0 = time()
200212
results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
201213
finally:
202-
if server is not None:
214+
if server is not None and server["process"] is not None:
203215
server["process"].terminate()
204216
server["process"].wait()
205217
if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
233245
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
234246
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
235247
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
236-
logger.info("")
237-
logger.info(
238-
"The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
239-
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
248+
249+
if path_db is not None:
250+
con = sqlite3.connect(path_db)
251+
cursor = con.cursor()
252+
cursor.execute(
253+
"CREATE TABLE IF NOT EXISTS server_bench"
254+
"(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
255+
"n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
256+
cursor.execute(
257+
"INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
258+
[name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
259+
con.commit()
240260

241261
plt.figure()
242262
plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
243263
plt.xlim(0, 1.05e0 * np.max(prompt_n))
244264
plt.ylim(0, 1.05e3 * np.max(prompt_t))
265+
plt.title(name or "")
245266
plt.xlabel("Prompt length [tokens]")
246267
plt.ylabel("Time to first token [ms]")
247268
plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
250271
plt.figure()
251272
plt.hist(token_t, np.arange(0, bin_max))
252273
plt.xlim(0, bin_max + 1)
274+
plt.title(name or "")
253275
plt.xlabel("Time [s]")
254276
plt.ylabel("Num. tokens generated per second")
255277
plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
259281
parser = argparse.ArgumentParser(
260282
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
261283
"Results are printed to console and visualized as plots (saved to current working directory). "
262-
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
284+
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
285+
"The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
286+
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
263287
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
264288
parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
289+
parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
290+
parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
265291
parser.add_argument(
266292
"--prompt_source", type=str, default="rng-1024-2048",
267293
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "

0 commit comments

Comments
 (0)