Skip to content

Commit a456911

Browse files
author
ochafik
committed
add scripts/tool_bench.sh & .py
1 parent b37779b commit a456911

File tree

4 files changed

+488
-41
lines changed

4 files changed

+488
-41
lines changed

examples/server/tests/unit/test_tool_call.py

100644100755
Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,7 @@ def create_server():
7474
}
7575

7676

77-
def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
78-
global server
79-
n_predict = 512
80-
# server = ServerPreset.stories15m_moe()
81-
server.jinja = True
82-
server.n_predict = n_predict
83-
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
84-
server.start(timeout_seconds=TIMEOUT_SERVER_START)
77+
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
8578
res = server.make_request("POST", "/v1/chat/completions", data={
8679
"max_tokens": n_predict,
8780
"messages": [
@@ -91,6 +84,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
9184
"tool_choice": "required",
9285
"tools": [tool],
9386
"parallel_tool_calls": False,
87+
**kwargs,
9488
})
9589
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
9690
choice = res.body["choices"][0]
@@ -113,7 +107,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
113107
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
114108
])
115109
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
116-
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
110+
global server
111+
n_predict = 512
112+
# server = ServerPreset.stories15m_moe()
113+
server.jinja = True
114+
server.n_predict = n_predict
115+
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
116+
server.start(timeout_seconds=TIMEOUT_SERVER_START)
117+
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
117118

118119

119120
@pytest.mark.slow
@@ -138,7 +139,14 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
138139
("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
139140
])
140141
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
141-
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
142+
global server
143+
n_predict = 512
144+
# server = ServerPreset.stories15m_moe()
145+
server.jinja = True
146+
server.n_predict = n_predict
147+
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
148+
server.start(timeout_seconds=TIMEOUT_SERVER_START)
149+
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
142150

143151

144152
@pytest.mark.slow
@@ -234,12 +242,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
234242
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
235243

236244

237-
def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
238-
global server
239-
server.jinja = True
240-
server.n_predict = n_predict
241-
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
242-
server.start(timeout_seconds=TIMEOUT_SERVER_START)
245+
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
243246
res = server.make_request("POST", "/v1/chat/completions", data={
244247
"max_tokens": n_predict,
245248
"messages": [
@@ -248,6 +251,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
248251
],
249252
"tools": tools if tools else None,
250253
"tool_choice": tool_choice,
254+
**kwargs,
251255
}, timeout=TIMEOUT_HTTP_REQUEST)
252256
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
253257
choice = res.body["choices"][0]
@@ -260,7 +264,12 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
260264
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
261265
])
262266
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
263-
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
267+
global server
268+
server.jinja = True
269+
server.n_predict = n_predict
270+
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
271+
server.start(timeout_seconds=TIMEOUT_SERVER_START)
272+
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
264273

265274

266275
@pytest.mark.slow
@@ -276,7 +285,12 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
276285
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
277286
])
278287
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
279-
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
288+
global server
289+
server.jinja = True
290+
server.n_predict = n_predict
291+
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
292+
server.start(timeout_seconds=TIMEOUT_SERVER_START)
293+
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
280294

281295

282296
@pytest.mark.slow
@@ -333,13 +347,17 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
333347
elif isinstance(template_override, str):
334348
server.chat_template = template_override
335349
server.start(timeout_seconds=TIMEOUT_SERVER_START)
350+
do_test_weather(server, max_tokens=n_predict)
351+
352+
353+
def do_test_weather(server: ServerProcess, **kwargs):
336354
res = server.make_request("POST", "/v1/chat/completions", data={
337-
"max_tokens": n_predict,
338355
"messages": [
339356
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
340357
{"role": "user", "content": "What is the weather in Istanbul?"},
341358
],
342359
"tools": [WEATHER_TOOL],
360+
**kwargs,
343361
}, timeout=TIMEOUT_HTTP_REQUEST)
344362
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
345363
choice = res.body["choices"][0]
@@ -387,6 +405,10 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
387405
elif isinstance(template_override, str):
388406
server.chat_template = template_override
389407
server.start(timeout_seconds=TIMEOUT_SERVER_START)
408+
do_test_calc_result(server, result_override, n_predict)
409+
410+
411+
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
390412
res = server.make_request("POST", "/v1/chat/completions", data={
391413
"max_tokens": n_predict,
392414
"messages": [
@@ -431,7 +453,8 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
431453
}
432454
}
433455
}
434-
]
456+
],
457+
**kwargs,
435458
}, timeout=TIMEOUT_HTTP_REQUEST)
436459
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
437460
choice = res.body["choices"][0]
@@ -548,13 +571,18 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
548571
elif isinstance(template_override, str):
549572
server.chat_template = template_override
550573
server.start(timeout_seconds=TIMEOUT_SERVER_START)
574+
575+
do_test_hello_world(server, max_tokens=n_predict)
576+
577+
578+
def do_test_hello_world(server: ServerProcess, **kwargs):
551579
res = server.make_request("POST", "/v1/chat/completions", data={
552-
"max_tokens": n_predict,
553580
"messages": [
554581
{"role": "system", "content": "You are a tool-calling agent."},
555582
{"role": "user", "content": "say hello world with python"},
556583
],
557584
"tools": [PYTHON_TOOL],
585+
**kwargs,
558586
}, timeout=TIMEOUT_HTTP_REQUEST)
559587
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
560588
choice = res.body["choices"][0]

examples/server/tests/utils.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30
30-
30+
REQUEST_RETRIES = int(os.environ.get('LLAMA_SERVER_TEST_REQUEST_RETRIES', '1'))
3131

3232
class ServerResponse:
3333
headers: dict
@@ -81,6 +81,7 @@ class ServerProcess:
8181
reasoning_format: Literal['deepseek', 'none'] | None = None
8282
chat_template: str | None = None
8383
chat_template_file: str | None = None
84+
server_path: str | None = None
8485

8586
# session variables
8687
process: subprocess.Popen | None = None
@@ -94,7 +95,9 @@ def __init__(self):
9495
self.server_port = int(os.environ["PORT"])
9596

9697
def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
97-
if "LLAMA_SERVER_BIN_PATH" in os.environ:
98+
if self.server_path is not None:
99+
server_path = self.server_path
100+
elif "LLAMA_SERVER_BIN_PATH" in os.environ:
98101
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
99102
elif os.name == "nt":
100103
server_path = "../../../build/bin/Release/llama-server.exe"
@@ -181,7 +184,7 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
181184
server_args.extend(["--chat-template-file", self.chat_template_file])
182185

183186
args = [str(arg) for arg in [server_path, *server_args]]
184-
print(f"bench: starting server with: {' '.join(args)}")
187+
print(f"tests: starting server with: {' '.join(args)}")
185188

186189
flags = 0
187190
if "nt" == os.name:
@@ -212,6 +215,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
212215
return # server is ready
213216
except Exception as e:
214217
pass
218+
# Check if process died
219+
if self.process.poll() is not None:
220+
raise RuntimeError(f"Server process died with return code {self.process.returncode}")
221+
215222
print(f"Waiting for server to start...")
216223
time.sleep(0.5)
217224
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
@@ -233,23 +240,31 @@ def make_request(
233240
timeout: float | None = None,
234241
) -> ServerResponse:
235242
url = f"http://{self.server_host}:{self.server_port}{path}"
236-
parse_body = False
237-
if method == "GET":
238-
response = requests.get(url, headers=headers, timeout=timeout)
239-
parse_body = True
240-
elif method == "POST":
241-
response = requests.post(url, headers=headers, json=data, timeout=timeout)
242-
parse_body = True
243-
elif method == "OPTIONS":
244-
response = requests.options(url, headers=headers, timeout=timeout)
245-
else:
246-
raise ValueError(f"Unimplemented method: {method}")
247-
result = ServerResponse()
248-
result.headers = dict(response.headers)
249-
result.status_code = response.status_code
250-
result.body = response.json() if parse_body else None
251-
print("Response from server", json.dumps(result.body, indent=2))
252-
return result
243+
for remaining_attempts in range(REQUEST_RETRIES, 0, -1):
244+
# print(f"#\ncurl {url} -d '{json.dumps(data, indent=2)}'\n")
245+
parse_body = False
246+
if method == "GET":
247+
response = requests.get(url, headers=headers, timeout=timeout)
248+
parse_body = True
249+
elif method == "POST":
250+
response = requests.post(url, headers=headers, json=data, timeout=timeout)
251+
parse_body = True
252+
elif method == "OPTIONS":
253+
response = requests.options(url, headers=headers, timeout=timeout)
254+
else:
255+
raise ValueError(f"Unimplemented method: {method}")
256+
257+
if (response is None or response.status_code != 200) and remaining_attempts > 0:
258+
continue
259+
result = ServerResponse()
260+
result.headers = dict(response.headers)
261+
result.status_code = response.status_code
262+
result.body = response.json() if parse_body else None
263+
# print("Response from server", json.dumps(result.body, indent=2))
264+
return result
265+
266+
raise RuntimeError(f"Failed to make request to {url} after {retries} attempts")
267+
253268

254269
def make_stream_request(
255270
self,

0 commit comments

Comments
 (0)