Skip to content

Commit 2efa0c2

Browse files
author
ochafik
committed
tool-call: add weather tool e2e tests
1 parent 15ec01e commit 2efa0c2

File tree

1 file changed

+86
-18
lines changed

1 file changed

+86
-18
lines changed

examples/server/tests/unit/test_chat_completion.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,23 @@ def test_chat_completion_with_timings_per_token():
221221
}
222222
}
223223

224+
WEATHER_TOOL = {
225+
"type":"function",
226+
"function":{
227+
"name":"get_current_weather",
228+
"description":"Get the current weather in a given location",
229+
"parameters":{
230+
"type":"object",
231+
"properties":{
232+
"location":{
233+
"type":"string",
234+
"description":"The city and state, e.g. San Francisco, CA"
235+
}
236+
},
237+
"required":["location"]
238+
}
239+
}
240+
}
224241

225242
@pytest.mark.parametrize("template_name,tool,argument_key", [
226243
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
@@ -308,22 +325,76 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
308325

309326

310327
@pytest.mark.slow
311-
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
312-
(PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
313-
(PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
314-
(PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
315-
(PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
316-
(PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
317-
(PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
318-
(PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
328+
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
329+
(None, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
330+
(None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
331+
(None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
332+
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
333+
(None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
334+
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
335+
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
319336
# TODO: fix these models
320-
(PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
321-
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
322-
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
337+
# (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
338+
# (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
339+
# (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
323340
])
324-
def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
341+
def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
325342
global server
326-
server.n_slots = 2
343+
server.n_slots = 1
344+
server.jinja = True
345+
server.n_ctx = 8192
346+
server.n_predict = 128
347+
server.model_hf_repo = hf_repo
348+
server.model_hf_file = hf_file
349+
if template_override:
350+
(template_hf_repo, template_variant) = template_override
351+
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
352+
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
353+
server.start(timeout_seconds=15*60)
354+
res = server.make_request("POST", "/chat/completions", data={
355+
"max_tokens": 256,
356+
"messages": [
357+
{"role": "user", "content": "What is the weather in Istanbul?"},
358+
],
359+
"tools": [WEATHER_TOOL],
360+
# "temperature": 0.5,
361+
# "top_k": 10,
362+
# "top_p": 0.9,
363+
})
364+
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
365+
choice = res.body["choices"][0]
366+
tool_calls = choice["message"].get("tool_calls")
367+
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
368+
tool_call = tool_calls[0]
369+
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
370+
actual_arguments = tool_call["function"]["arguments"]
371+
if expected_arguments is not None:
372+
assert actual_arguments == expected_arguments
373+
else:
374+
actual_arguments = json.loads(actual_arguments)
375+
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
376+
location = actual_arguments["location"]
377+
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
378+
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
379+
380+
381+
@pytest.mark.slow
382+
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
383+
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
384+
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
385+
(None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
386+
(None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
387+
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
388+
(None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
389+
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
390+
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
391+
# TODO: fix these models
392+
# (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
393+
# (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
394+
])
395+
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
396+
global server
397+
server.n_slots = 1
327398
server.jinja = True
328399
server.n_ctx = 8192
329400
server.n_predict = 128
@@ -341,7 +412,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
341412
{"role": "user", "content": "say hello world with python"},
342413
# {"role": "user", "content": "Print a hello world message with python"},
343414
],
344-
"tools": [tool],
415+
"tools": [PYTHON_TOOL],
345416
"temperature": 0.5,
346417
"top_k": 10,
347418
"top_p": 0.9,
@@ -351,10 +422,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
351422
tool_calls = choice["message"].get("tool_calls")
352423
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
353424
tool_call = tool_calls[0]
354-
if tool["type"] == "function":
355-
assert tool["function"]["name"] == tool_call["function"]["name"]
356-
elif tool["type"] == "code_interpreter":
357-
assert re.match('i?python', tool_call["function"]["name"])
425+
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
358426
actual_arguments = tool_call["function"]["arguments"]
359427
if expected_arguments is not None:
360428
assert actual_arguments == expected_arguments

0 commit comments

Comments
 (0)