Skip to content

Commit 0f02ab8

Browse files
scripts: synthetic prompt mode for server-bench.py
1 parent 494c589 commit 0f02ab8

File tree

2 files changed

+123
-66
lines changed

2 files changed

+123
-66
lines changed

scripts/server-bench.py

100644100755
Lines changed: 122 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import argparse
44
import json
5+
import os
6+
import random
57
import subprocess
68
from time import sleep, time
79
from typing import Optional
@@ -18,46 +20,54 @@
1820
logger = logging.getLogger("server-bench")
1921

2022

21-
def get_prompts(n_prompts: int) -> list[str]:
22-
logger.info("Loading MMLU dataset...")
23-
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
23+
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
24+
ret = []
25+
if dataset_name.lower() == "mmlu":
26+
logger.info("Loading MMLU dataset...")
27+
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
28+
else:
29+
return None
2430
if n_prompts >= 0:
2531
ret = ret[:n_prompts]
2632
return ret
2733

2834

29-
def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
35+
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
36+
assert n_prompts >= 0
37+
ret: list[list[int]] = []
38+
for i in range(n_prompts):
39+
random.seed(13 * i + 0)
40+
ret.append(random.randint(prompt_length_min, prompt_length_max))
41+
return ret
42+
43+
44+
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
45+
return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
46+
47+
48+
def get_server(path_server: str, path_log: Optional[str]) -> dict:
3049
logger.info("Starting the llama.cpp server...")
31-
address = f"http://localhost:{port}"
32-
33-
popen_args: list[str] = [
34-
path_server,
35-
"--flash-attn",
36-
"--n-gpu-layers", str(n_gpu_layers),
37-
"--parallel", str(parallel),
38-
"--ctx-size", str(parallel * ctx_size),
39-
"--model", path_model,
40-
"--port", str(port),
41-
"--swa-full", # FIXME performance bad otherwise
42-
# "--attn-streams",
43-
]
44-
fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
45-
process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
50+
hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
51+
port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
52+
address: str = f"http://{hostname}:{port}"
53+
54+
fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
55+
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
4656

4757
n_failures: int = 0
4858
while True:
4959
try:
5060
sleep(1.0)
5161
exit_code = process.poll()
5262
if exit_code is not None:
53-
raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
63+
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
5464
response = requests.get(f"{address}/health")
5565
if response.status_code == 200:
5666
break
5767
except requests.ConnectionError:
5868
n_failures += 1
5969
if n_failures >= 10:
60-
raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
70+
raise RuntimeError(f"llama.cpp server is not healthy after 10 seconds")
6171

6272
return {"process": process, "address": address, "fout": fout}
6373

@@ -87,76 +97,118 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
8797
session = data["session"]
8898
server_address: str = data["server_address"]
8999

90-
response = session.post(
91-
f"{server_address}/apply-template",
92-
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
93-
)
94-
if response.status_code != 200:
95-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
96-
prompt: str = json.loads(response.text)["prompt"]
97-
98-
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
99-
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
100+
t_submit = time()
101+
if data["synthetic_prompt"]:
102+
json_data: dict = {
103+
"prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
104+
"seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
105+
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
106+
else:
107+
response = session.post(
108+
f"{server_address}/apply-template",
109+
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
110+
)
111+
if response.status_code != 200:
112+
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
113+
prompt: str = json.loads(response.text)["prompt"]
114+
115+
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
116+
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
100117

101118
last_valid_line: str = ""
102119
token_arrival_times: list[float] = []
103-
for line in response.iter_lines(decode_unicode=True):
104-
if not line.startswith("data: "):
120+
for line in response.iter_lines(decode_unicode=False):
121+
if not line.startswith(b"data: "):
105122
continue
106123
last_valid_line = line
107124
token_arrival_times.append(time())
108125
token_arrival_times = token_arrival_times[:-1]
109126

110127
if response.status_code != 200:
111128
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
112-
timings: dict = json.loads(last_valid_line[6:])["timings"]
113129

114-
return (timings["prompt_ms"], token_arrival_times)
115-
116-
117-
def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
118-
num_workers: int = parallel + 1
119-
prompts: list[str] = get_prompts(n_prompts)
130+
return (t_submit, token_arrival_times)
131+
132+
133+
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
134+
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
135+
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
136+
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
137+
if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
138+
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
139+
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
140+
if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
141+
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
142+
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
143+
144+
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
145+
prompts: Optional[list[str]] = get_prompts_text(prompt_source, n_prompts)
146+
synthetic_prompts: bool = prompts is None
147+
prompt_n = []
148+
149+
if synthetic_prompts:
150+
prompt_source_split: list[str] = prompt_source.split("-")
151+
assert len(prompt_source_split) == 3
152+
assert prompt_source_split[0].lower() == "rng"
153+
prompt_length_min: int = int(prompt_source_split[1])
154+
prompt_length_max: int = int(prompt_source_split[2])
155+
logger.info("Generating random prompts...")
156+
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
157+
prompts = get_prompts_rng(prompt_n)
158+
else:
159+
n_predict_min = n_predict
160+
161+
if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
162+
context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
163+
context_total: int = context_per_slot * parallel
164+
os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
165+
logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
120166

121167
server: Optional[dict] = None
122168
session = None
123169
try:
124-
server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
170+
server = get_server(path_server, path_log)
125171
server_address: str = server["address"]
126172

127-
adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers) # type: ignore
173+
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
128174
session = requests.Session()
129175
session.mount("http://", adapter)
130176
session.mount("https://", adapter)
131177

132178
data: list[dict] = []
179+
133180
for i, p in enumerate(prompts):
134-
data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
181+
random.seed(13 * i + 1)
182+
data.append({
183+
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
184+
"n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
135185

136-
logger.info("Getting the prompt lengths...")
137-
prompt_n = [get_prompt_length(d) for d in data]
186+
if not synthetic_prompts:
187+
logger.info("Getting the prompt lengths...")
188+
prompt_n = [get_prompt_length(d) for d in data]
138189

139190
logger.info("Starting the benchmark...\n")
140191
t0 = time()
141-
results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
192+
results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
142193
finally:
143194
if server is not None:
144195
server["process"].terminate()
145196
server["process"].wait()
146197
if session is not None:
147198
session.close()
148199

149-
prompt_ms = []
200+
prompt_t = []
150201
token_t = []
151202
depth_sum: int = 0
152-
for pn, (pms, tat) in zip(prompt_n, results):
153-
prompt_ms.append(pms)
203+
for pn, (t_submit, tat) in zip(prompt_n, results):
204+
prompt_t.append(tat[0] - t_submit)
154205
token_t += tat
155206
n_tokens: int = len(tat)
156207
depth_sum += n_tokens * pn
157208
depth_sum += n_tokens * (n_tokens + 1) // 2
209+
assert len(token_t) > 0
158210
prompt_n = np.array(prompt_n, dtype=np.int64)
159-
prompt_ms = np.array(prompt_ms, dtype=np.float64)
211+
prompt_t = np.array(prompt_t, dtype=np.float64)
160212
token_t = np.array(token_t, dtype=np.float64)
161213

162214
token_t -= t0
@@ -167,18 +219,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
167219
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
168220
logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
169221
logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
170-
logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
171-
logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
222+
logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
223+
logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
172224
logger.info(f"Total generated tokens: {token_t.shape[0]}")
173225
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
174226
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
175227
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
228+
logger.info("")
229+
logger.info(
230+
"The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
231+
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
176232

177233
plt.figure()
178-
plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
179-
plt.xlim(0, 1.05 * np.max(prompt_n))
180-
plt.ylim(0, 1.05 * np.max(prompt_ms))
181-
plt.title(path_model)
234+
plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
235+
plt.xlim(0, 1.05e0 * np.max(prompt_n))
236+
plt.ylim(0, 1.05e3 * np.max(prompt_t))
182237
plt.xlabel("Prompt length [tokens]")
183238
plt.ylabel("Time to first token [ms]")
184239
plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +242,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
187242
plt.figure()
188243
plt.hist(token_t, np.arange(0, bin_max))
189244
plt.xlim(0, bin_max + 1)
190-
plt.title(path_model)
191245
plt.xlabel("Time [s]")
192246
plt.ylabel("Num. tokens generated per second")
193247
plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +250,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
196250
if __name__ == "__main__":
197251
parser = argparse.ArgumentParser(
198252
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
199-
"Results are printed to console and visualized as plots (saved to current working directory).")
253+
"Results are printed to console and visualized as plots (saved to current working directory). "
254+
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
200255
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
201-
parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
202-
parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
203-
parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
204-
parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
205-
parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
206-
parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
207-
parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
256+
parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
257+
parser.add_argument(
258+
"--prompt_source", type=str, default="rng-1024-2048",
259+
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
260+
"rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
261+
parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
208262
parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
263+
parser.add_argument(
264+
"--n_predict_min", type=int, default=1024,
265+
help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
209266
args = parser.parse_args()
210267
benchmark(**vars(args))

tools/server/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
77
**Features:**
88
* LLM inference of F16 and quantized models on GPU and CPU
99
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
10-
* Reranking endoint (https://github.com/ggml-org/llama.cpp/pull/9510)
10+
* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
1111
* Parallel decoding with multi-user support
1212
* Continuous batching
1313
* Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support

0 commit comments

Comments
 (0)