Commit bbd0f91

server-bench: make seed choice configurable (#14929)
* server-bench: make seed choice configurable
* Update scripts/server-bench.py
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Update scripts/server-bench.py
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* fix error formatting
* Update scripts/server-bench.py
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
---------
Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 0a5036b commit bbd0f91

File tree

1 file changed (+25, -13 lines)

scripts/server-bench.py

Lines changed: 25 additions & 13 deletions
@@ -32,11 +32,12 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
     return ret
 
 
-def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
     assert n_prompts >= 0
     ret: list[int] = []
     for i in range(n_prompts):
-        random.seed(13 * i + 0)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 0)
         ret.append(random.randint(prompt_length_min, prompt_length_max))
     return ret
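
The three constants +0, +1, and +2 in the new scheme keep the per-prompt seeds for the prompt length, the number of tokens to predict, and the generation seed distinct, and a negative --seed_offset skips seeding entirely. As a rough sketch (the helper name derive_seed is hypothetical, not part of the script), the derivation and the seed reuse that the new help text warns about for offsets >= 1000 look like this:

def derive_seed(seed_offset: int, i: int, k: int) -> int:
    # Hypothetical helper mirroring the expressions above;
    # k = 0: prompt length, k = 1: n_predict, k = 2: per-request generation seed.
    return 3 * (seed_offset + 1000 * i) + k

# The three uses for one prompt index never collide with each other:
assert {derive_seed(0, 5, k) for k in range(3)} == {15000, 15001, 15002}

# But two runs whose offsets differ by a multiple of 1000 reuse seeds at
# shifted prompt indices, e.g. run A (offset 1000, prompt 0) and run B
# (offset 0, prompt 1) both seed with 3000:
assert derive_seed(1000, 0, 0) == derive_seed(0, 1, 0) == 3000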

@@ -46,12 +47,20 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
-    logger.info("Starting the llama.cpp server...")
-    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
-    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    if os.environ.get("LLAMA_ARG_HOST") is None:
+        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
+        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
+    if os.environ.get("LLAMA_ARG_PORT") is None:
+        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
+        os.environ["LLAMA_ARG_PORT"] = "8080"
+    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
+    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
+    assert hostname is not None
+    assert port is not None
     address: str = f"http://{hostname}:{port}"
+    logger.info(f"Starting the llama.cpp server under {address}...")
 
-    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
     process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
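
The default log path now carries a {port} placeholder (see the argparse change further down), which get_server fills in with str.format so that benchmarks against different ports do not clobber each other's logs. A minimal sketch of that substitution, assuming the port comes from LLAMA_ARG_PORT as above:

import os

# Illustration only: default the port the same way get_server() does.
os.environ.setdefault("LLAMA_ARG_PORT", "8080")
port = os.environ["LLAMA_ARG_PORT"]

path_log = "server-bench-{port}.log"
print(path_log.format(port=port))            # server-bench-8080.log
# A log path without the placeholder passes through str.format unchanged:
print("server-bench.log".format(port=port))  # server-bench.log
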
@@ -60,7 +69,7 @@ def get_server(path_server: str, path_log: Optional[str]) -> dict:
         sleep(1.0)
         exit_code = process.poll()
         if exit_code is not None:
-            raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
+            raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
         response = requests.get(f"{address}/health")
         if response.status_code == 200:
             break
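
The reworked error message appends the ", see ..." hint only when a log path was given, using Python's short-circuiting and/or inside the f-string: a falsy path_log yields an empty string, a truthy one yields the formatted suffix. A small sketch of the idiom with hypothetical values:

# Hypothetical values demonstrating the `x and y or z` pattern used above.
port = "8080"
for path_log in ("server-bench-{port}.log", None):
    suffix = path_log and f", see {path_log.format(port=port)}" or ""
    print(repr(suffix))
# ', see server-bench-8080.log'
# ''
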
@@ -128,7 +137,7 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
@@ -139,7 +148,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
-    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore
     prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
     synthetic_prompts: bool = prompts is None
     prompt_n = []
prompt_n = []
@@ -151,7 +160,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
151160
prompt_length_min: int = int(prompt_source_split[1])
152161
prompt_length_max: int = int(prompt_source_split[2])
153162
logger.info("Generating random prompts...")
154-
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
163+
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
155164
prompts = get_prompts_rng(prompt_n)
156165
else:
157166
n_predict_min = n_predict
@@ -176,10 +185,11 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
     data: list[dict] = []
 
     for i, p in enumerate(prompts):
-        random.seed(13 * i + 1)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 1)
         data.append({
             "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-            "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+            "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
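
Each request now also carries either a deterministic per-prompt generation seed (the +2 member of the scheme) or -1 when --seed_offset is negative, in which case seeding is left to the server. A small sketch of just those request fields, with hypothetical defaults for the prediction bounds:

import random

def request_fields(i: int, seed_offset: int, n_predict_min: int = 1024, n_predict: int = 2048) -> dict:
    # Mirrors only the seed-related parts of the dict built in benchmark().
    if seed_offset >= 0:
        random.seed(3 * (seed_offset + 1000 * i) + 1)
    return {
        "n_predict": random.randint(n_predict_min, n_predict),
        # Deterministic per-prompt seed, or -1 to let the server pick one:
        "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1,
    }

print(request_fields(0, seed_offset=0))   # identical on every run
print(request_fields(0, seed_offset=-1))  # n_predict varies, seed is -1
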
@@ -251,7 +261,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
         "Results are printed to console and visualized as plots (saved to current working directory). "
         "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
@@ -261,5 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
     parser.add_argument(
         "--n_predict_min", type=int, default=1024,
         help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
+    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
+                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
     args = parser.parse_args()
     benchmark(**vars(args))
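
Because benchmark is called as benchmark(**vars(args)), every argparse flag, including the new --seed_offset, must map one-to-one to a keyword parameter of benchmark. A stripped-down sketch of that flow (only the seed-related pieces are reproduced, everything else is omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seed_offset", type=int, default=0)
args = parser.parse_args(["--seed_offset=-1"])  # hypothetical command line

def benchmark(seed_offset: int) -> None:
    # In the real script a negative offset disables every random.seed() call.
    print(f"seed_offset={seed_offset}, deterministic={seed_offset >= 0}")

benchmark(**vars(args))  # seed_offset=-1, deterministic=False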
