@@ -32,11 +32,12 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
     return ret
 
 
-def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
     assert n_prompts >= 0
     ret: list[int] = []
     for i in range(n_prompts):
-        random.seed(13 * i + 0)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 0)
         ret.append(random.randint(prompt_length_min, prompt_length_max))
     return ret
 
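Note (reviewer sketch, not part of the patch): the new scheme derives three seeds per prompt i, 3 * (seed_offset + 1000 * i) + c, with c = 0 for the prompt length, c = 1 for n_predict, and c = 2 for the generation seed sent to the server, so the three RNG uses of one prompt never collide. Two runs only reuse seeds when their seed_offset values differ by a multiple of 1000, which is what the --seed_offset help text further down warns about. A minimal self-contained check:

def seeds_for_run(seed_offset: int, n_prompts: int) -> set[int]:
    # Mirrors the formula introduced above: three seeds per prompt, offset by c = 0, 1, 2.
    return {3 * (seed_offset + 1000 * i) + c for i in range(n_prompts) for c in range(3)}

print(seeds_for_run(0, 8) & seeds_for_run(1, 8))     # set() -> small offsets never collide
print(seeds_for_run(0, 8) & seeds_for_run(1000, 8))  # non-empty -> offsets 1000 apart alias each other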
@@ -46,12 +47,18 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
-    logger.info("Starting the llama.cpp server...")
-    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
-    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    if os.environ.get("LLAMA_ARG_HOST") is None:
+        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
+        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
+    if os.environ.get("LLAMA_ARG_PORT") is None:
+        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
+        os.environ["LLAMA_ARG_PORT"] = "8080"
+    hostname: str = os.environ.get("LLAMA_ARG_HOST")
+    port: str = os.environ.get("LLAMA_ARG_PORT")
     address: str = f"http://{hostname}:{port}"
+    logger.info(f"Starting the llama.cpp server under {address}...")
 
-    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
     process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
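Note (reviewer sketch, not part of the patch): the log file name is now derived from the server port via str.format, so concurrent benchmark runs against servers on different ports write separate logs. A small illustration using the new default value introduced in the argparse section below:

import os

path_log = "server-bench-{port}.log"             # new default for --path_log
port = os.environ.get("LLAMA_ARG_PORT", "8080")  # same fallback as get_server()
print(path_log.format(port=port))                # -> server-bench-8080.log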
@@ -128,7 +135,7 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
@@ -139,7 +146,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
-    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))
     prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
     synthetic_prompts: bool = prompts is None
     prompt_n = []
@@ -151,7 +158,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         prompt_length_min: int = int(prompt_source_split[1])
         prompt_length_max: int = int(prompt_source_split[2])
         logger.info("Generating random prompts...")
-        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
         prompts = get_prompts_rng(prompt_n)
     else:
         n_predict_min = n_predict
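Note (reviewer sketch, not part of the patch): the synthetic prompt_source string parsed above encodes the prompt length range, e.g. the default "rng-1024-2048" shown in the argparse section below:

prompt_source = "rng-1024-2048"
kind, length_min, length_max = prompt_source.split("-")
assert kind == "rng"
print(int(length_min), int(length_max))  # -> 1024 2048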
@@ -176,10 +183,11 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     data: list[dict] = []
 
     for i, p in enumerate(prompts):
-        random.seed(13 * i + 1)
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 1)
         data.append({
             "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-            "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+            "n_predict": random.randint(n_predict_min, n_predict), "seed": 3 * (seed_offset + 1000 * i) + 2 if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
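Note (reviewer sketch, not part of the patch): a standalone reproduction of how each request's n_predict and generation seed are now chosen; with a negative seed_offset the RNG is left unseeded and the request carries seed -1 (no fixed seed, per the --seed_offset help text):

import random

def request_fields(i: int, seed_offset: int, n_predict_min: int, n_predict: int) -> dict:
    # Mirrors the loop above: seed the RNG per prompt (c = 1) before drawing n_predict,
    # then derive the per-request generation seed (c = 2), or -1 when seeding is disabled.
    if seed_offset >= 0:
        random.seed(3 * (seed_offset + 1000 * i) + 1)
    return {
        "n_predict": random.randint(n_predict_min, n_predict),
        "seed": 3 * (seed_offset + 1000 * i) + 2 if seed_offset >= 0 else -1,
    }

print(request_fields(0, seed_offset=0, n_predict_min=1024, n_predict=2048))   # reproducible
print(request_fields(0, seed_offset=-1, n_predict_min=1024, n_predict=2048))  # varies per run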
@@ -251,7 +259,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         "Results are printed to console and visualized as plots (saved to current working directory). "
         "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the log file for the benchmark")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
@@ -261,5 +269,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     parser.add_argument(
         "--n_predict_min", type=int, default=1024,
         help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
+    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
+                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
     args = parser.parse_args()
     benchmark(**vars(args))