@@ -32,11 +32,12 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
     return ret
 
 
-def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
     assert n_prompts >= 0
     ret: list[int] = []
     for i in range(n_prompts):
-        random.seed(13 * i + 0)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 0)
         ret.append(random.randint(prompt_length_min, prompt_length_max))
     return ret
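Reviewer note, not part of the patch: the fixed per-prompt seed 13 * i + k becomes 3 * (seed_offset + 1000 * i) + k, where k = 0/1/2 separates the prompt-length draw, the n_predict draw, and the seed sent to the server, and a negative seed_offset disables seeding altogether. A minimal sketch of the derivation (derive_seed is a hypothetical helper, not from the patch):

# Sketch only, not part of the patch.
def derive_seed(seed_offset: int, i: int, k: int) -> int:
    # k = 0: prompt length, k = 1: n_predict draw, k = 2: seed passed to the server
    assert 0 <= k < 3
    return 3 * (seed_offset + 1000 * i) + k

# Prompt 1 at seed_offset 0 reuses the value that prompt 0 at seed_offset 1000 gets,
# which is what the new --seed_offset help text warns about for offsets >= 1000:
assert derive_seed(0, 1, 0) == derive_seed(1000, 0, 0) == 3000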
@@ -46,12 +47,20 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
-    logger.info("Starting the llama.cpp server...")
-    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
-    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    if os.environ.get("LLAMA_ARG_HOST") is None:
+        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
+        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
+    if os.environ.get("LLAMA_ARG_PORT") is None:
+        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
+        os.environ["LLAMA_ARG_PORT"] = "8080"
+    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
+    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
+    assert hostname is not None
+    assert port is not None
     address: str = f"http://{hostname}:{port}"
+    logger.info(f"Starting the llama.cpp server under {address}...")
 
-    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
     process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
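Reviewer note, not part of the patch: with the new default --path_log of "server-bench-{port}.log", each server instance logs to a file named after its port, so several benchmarks can run side by side without clobbering each other's logs. A rough illustration under that assumption:

# Sketch only, not part of the patch; mirrors the defaulting and formatting above.
import os

os.environ.setdefault("LLAMA_ARG_PORT", "8080")  # same fallback the patch applies
port = os.environ["LLAMA_ARG_PORT"]
path_log = "server-bench-{port}.log"             # new default of --path_log
print(path_log.format(port=port))                # -> server-bench-8080.log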
@@ -60,7 +69,7 @@ def get_server(path_server: str, path_log: Optional[str]) -> dict:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
@@ -128,7 +137,7 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
@@ -139,7 +148,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
-    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore
     prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
     synthetic_prompts: bool = prompts is None
     prompt_n = []
@@ -151,7 +160,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         prompt_length_min: int = int(prompt_source_split[1])
         prompt_length_max: int = int(prompt_source_split[2])
         logger.info("Generating random prompts...")
-        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
         prompts = get_prompts_rng(prompt_n)
     else:
         n_predict_min = n_predict
@@ -176,10 +185,11 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     data: list[dict] = []
 
     for i, p in enumerate(prompts):
-        random.seed(13 * i + 1)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 1)
         data.append({
             "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-            "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+            "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
@@ -251,7 +261,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         "Results are printed to console and visualized as plots (saved to current working directory). "
         "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
@@ -261,5 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     parser.add_argument(
         "--n_predict_min", type=int, default=1024,
         help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
+    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
+        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
     args = parser.parse_args()
     benchmark(**vars(args))
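Reviewer note, not part of the patch: the new flag reaches benchmark() through **vars(args) like every other argument. A quick check with a hand-built argv, assuming every other flag keeps its default:

# Sketch only, not part of the patch.
args = parser.parse_args(["--seed_offset", "42"])
print(args.seed_offset)  # -> 42, forwarded into benchmark(**vars(args))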