Merged
Changes from 4 commits
38 changes: 25 additions & 13 deletions scripts/server-bench.py
@@ -32,11 +32,12 @@
return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
assert n_prompts >= 0
ret: list[int] = []
for i in range(n_prompts):
random.seed(13 * i + 0)
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 0)
ret.append(random.randint(prompt_length_min, prompt_length_max))
return ret

@@ -46,12 +47,20 @@


def get_server(path_server: str, path_log: Optional[str]) -> dict:
logger.info("Starting the llama.cpp server...")
hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
if os.environ.get("LLAMA_ARG_HOST") is None:
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
if os.environ.get("LLAMA_ARG_PORT") is None:
logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
os.environ["LLAMA_ARG_PORT"] = "8080"
hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
assert hostname is not None
assert port is not None
address: str = f"http://{hostname}:{port}"
logger.info(f"Starting the llama.cpp server under {address}...")

fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

n_failures: int = 0
@@ -60,7 +69,7 @@
sleep(1.0)
exit_code = process.poll()
if exit_code is not None:
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log.format(port=port)}")

[Check failure on line 72 in scripts/server-bench.py — GitHub Actions / pyright type-check: "format" is not a known attribute of "None" (reportOptionalMemberAccess). See the note after this hunk.]
response = requests.get(f"{address}/health")
if response.status_code == 200:
break
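
Note on the pyright failure flagged above: path_log is typed Optional[str], so calling path_log.format(port=port) in the error message is rejected because path_log may still be None on that path. A minimal sketch of one way to satisfy the checker (not the change made in this PR; the local name log_path is hypothetical) is to resolve the optional path once and reuse the narrowed value:

    # Sketch only: format the optional log path a single time so every later use sees a plain str or None.
    log_path = path_log.format(port=port) if path_log is not None else None
    fout = open(log_path, "w") if log_path is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
    ...
    if exit_code is not None:
        raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {log_path}")
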
@@ -128,7 +137,7 @@
return (t_submit, token_arrival_times)


def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
@@ -139,7 +148,7 @@
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
synthetic_prompts: bool = prompts is None
prompt_n = []
@@ -151,7 +160,7 @@
prompt_length_min: int = int(prompt_source_split[1])
prompt_length_max: int = int(prompt_source_split[2])
logger.info("Generating random prompts...")
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
prompts = get_prompts_rng(prompt_n)
else:
n_predict_min = n_predict
@@ -176,10 +185,11 @@
data: list[dict] = []

for i, p in enumerate(prompts):
random.seed(13 * i + 1)
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1)
data.append({
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
"n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
"n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

if not synthetic_prompts:
logger.info("Getting the prompt lengths...")
@@ -251,7 +261,7 @@
"Results are printed to console and visualized as plots (saved to current working directory). "
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
parser.add_argument(
"--prompt_source", type=str, default="rng-1024-2048",
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
Expand All @@ -261,5 +271,7 @@
parser.add_argument(
"--n_predict_min", type=int, default=1024,
help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
"Corelations between seeds can occur when set >= 1000. Negative values mean no seed.")
args = parser.parse_args()
benchmark(**vars(args))
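
Worked example for the --seed_offset help above, using the formulas from this diff: prompt i derives its seeds from the base seed_offset + 1000 * i, multiplied by 3 and shifted by 0, 1, or 2 for the prompt length, the generation length, and the server seed respectively. Two runs whose offsets differ by a multiple of 1000 therefore reuse each other's bases at shifted prompt indices, which is the correlation the help text warns about; a negative offset skips the random.seed() calls entirely and sends seed -1 to the server.

    # Illustrative check of the seed arithmetic (not part of the script):
    # run A with seed_offset=0 at prompt i=1 and run B with seed_offset=1000 at prompt i=0
    # share the same base, hence the same prompt-length seed.
    def seed_base(seed_offset: int, i: int) -> int:
        return seed_offset + 1000 * i

    assert 3 * seed_base(0, 1) + 0 == 3 * seed_base(1000, 0) + 0  # both equal 3000
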
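The script description above notes that server settings are passed through environment variables rather than script flags. A hypothetical invocation sketch, assuming the env var names reported by llama-server --help (LLAMA_ARG_MODEL, LLAMA_ARG_N_PARALLEL) and illustrative parameter values:

    # Hypothetical wrapper: configure the server through the environment, then run the benchmark.
    import os
    os.environ["LLAMA_ARG_MODEL"] = "/path/to/model.gguf"  # assumed model path
    os.environ["LLAMA_ARG_N_PARALLEL"] = "16"
    benchmark(path_server="llama-server", path_log="server-bench-{port}.log",
              prompt_source="rng-1024-2048", n_prompts=100, n_predict=2048,
              n_predict_min=1024, seed_offset=0)
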