Skip to content

Commit cafea60

Browse files
author
ochafik
committed
Split e2e test_tool_call from test_chat_completion
1 parent 90effb8 commit cafea60

File tree

2 files changed

+317
-234
lines changed

2 files changed

+317
-234
lines changed

examples/server/tests/unit/test_chat_completion.py

Lines changed: 0 additions & 234 deletions
Original file line numberDiff line numberDiff line change
@@ -188,240 +188,6 @@ def test_chat_completion_with_timings_per_token():
188188
assert data["timings"]["predicted_n"] <= 10
189189

190190

191-
TEST_TOOL = {
192-
"type":"function",
193-
"function": {
194-
"name": "test",
195-
"description": "",
196-
"parameters": {
197-
"type": "object",
198-
"properties": {
199-
"success": {"type": "boolean", "const": True},
200-
},
201-
"required": ["success"]
202-
}
203-
}
204-
}
205-
206-
PYTHON_TOOL = {
207-
"type": "function",
208-
"function": {
209-
"name": "python",
210-
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
211-
"parameters": {
212-
"type": "object",
213-
"properties": {
214-
"code": {
215-
"type": "string",
216-
"description": "The code to run in the ipython interpreter."
217-
}
218-
},
219-
"required": ["code"]
220-
}
221-
}
222-
}
223-
224-
WEATHER_TOOL = {
225-
"type":"function",
226-
"function":{
227-
"name":"get_current_weather",
228-
"description":"Get the current weather in a given location",
229-
"parameters":{
230-
"type":"object",
231-
"properties":{
232-
"location":{
233-
"type":"string",
234-
"description":"The city and state, e.g. San Francisco, CA"
235-
}
236-
},
237-
"required":["location"]
238-
}
239-
}
240-
}
241-
242-
@pytest.mark.parametrize("template_name,tool,argument_key", [
243-
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
244-
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
245-
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
246-
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
247-
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
248-
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
249-
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
250-
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
251-
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
252-
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
253-
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
254-
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
255-
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
256-
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
257-
])
258-
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
259-
n_predict = 512
260-
global server
261-
# server = ServerPreset.stories15m_moe()
262-
server.jinja = True
263-
server.n_predict = n_predict
264-
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
265-
server.start()
266-
res = server.make_request("POST", "/chat/completions", data={
267-
"max_tokens": n_predict,
268-
"messages": [
269-
{"role": "system", "content": "You are a coding assistant."},
270-
{"role": "user", "content": "Write an example"},
271-
],
272-
"tool_choice": "required",
273-
"tools": [tool],
274-
"parallel_tool_calls": False,
275-
"temperature": 0.0,
276-
"top_k": 1,
277-
"top_p": 1.0,
278-
})
279-
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
280-
choice = res.body["choices"][0]
281-
tool_calls = choice["message"].get("tool_calls")
282-
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
283-
tool_call = tool_calls[0]
284-
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
285-
assert expected_function_name == tool_call["function"]["name"]
286-
actual_arguments = tool_call["function"]["arguments"]
287-
assert isinstance(actual_arguments, str)
288-
if argument_key is not None:
289-
actual_arguments = json.loads(actual_arguments)
290-
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
291-
292-
293-
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
294-
("meetkai-functionary-medium-v3.1", 128, [], None),
295-
("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
296-
("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
297-
("meetkai-functionary-medium-v3.2", 128, [], None),
298-
("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
299-
("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
300-
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
301-
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
302-
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
303-
])
304-
def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
305-
global server
306-
server.jinja = True
307-
server.n_predict = n_predict
308-
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
309-
server.start()
310-
res = server.make_request("POST", "/chat/completions", data={
311-
"max_tokens": n_predict,
312-
"messages": [
313-
{"role": "system", "content": "You are a coding assistant."},
314-
{"role": "user", "content": "say hello world with python"},
315-
],
316-
"tools": tools if tools else None,
317-
"tool_choice": tool_choice,
318-
"temperature": 0.0,
319-
"top_k": 1,
320-
"top_p": 1.0,
321-
})
322-
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
323-
choice = res.body["choices"][0]
324-
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
325-
326-
327-
@pytest.mark.slow
328-
@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
329-
("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
330-
("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
331-
("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
332-
("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
333-
("NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
334-
("NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
335-
("bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
336-
("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
337-
("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
338-
("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
339-
])
340-
def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
341-
global server
342-
server.n_slots = 1
343-
server.jinja = True
344-
server.n_ctx = 8192
345-
server.n_predict = 128
346-
server.model_hf_repo = hf_repo
347-
server.model_hf_file = hf_file
348-
if template_override:
349-
(template_hf_repo, template_variant) = template_override
350-
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
351-
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
352-
server.start(timeout_seconds=15*60)
353-
res = server.make_request("POST", "/chat/completions", data={
354-
"max_tokens": 256,
355-
"messages": [
356-
{"role": "user", "content": "What is the weather in Istanbul?"},
357-
],
358-
"tools": [WEATHER_TOOL],
359-
})
360-
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
361-
choice = res.body["choices"][0]
362-
tool_calls = choice["message"].get("tool_calls")
363-
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
364-
tool_call = tool_calls[0]
365-
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
366-
actual_arguments = json.loads(tool_call["function"]["arguments"])
367-
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
368-
location = actual_arguments["location"]
369-
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
370-
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
371-
372-
373-
@pytest.mark.slow
374-
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
375-
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
376-
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
377-
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
378-
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
379-
(None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
380-
(None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
381-
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
382-
(None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
383-
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
384-
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
385-
])
386-
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
387-
global server
388-
server.n_slots = 1
389-
server.jinja = True
390-
server.n_ctx = 8192
391-
server.n_predict = 128
392-
server.model_hf_repo = hf_repo
393-
server.model_hf_file = hf_file
394-
if template_override:
395-
(template_hf_repo, template_variant) = template_override
396-
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
397-
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
398-
server.start(timeout_seconds=15*60)
399-
res = server.make_request("POST", "/chat/completions", data={
400-
"max_tokens": 256,
401-
"messages": [
402-
{"role": "system", "content": "You are a coding assistant."},
403-
{"role": "user", "content": "say hello world with python"},
404-
# {"role": "user", "content": "Print a hello world message with python"},
405-
],
406-
"tools": [PYTHON_TOOL],
407-
})
408-
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
409-
choice = res.body["choices"][0]
410-
tool_calls = choice["message"].get("tool_calls")
411-
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
412-
tool_call = tool_calls[0]
413-
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
414-
actual_arguments = tool_call["function"]["arguments"]
415-
if expected_arguments is not None:
416-
assert actual_arguments == expected_arguments
417-
else:
418-
actual_arguments = json.loads(actual_arguments)
419-
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
420-
code = actual_arguments["code"]
421-
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
422-
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
423-
424-
425191
def test_logprobs():
426192
global server
427193
server.start()

0 commit comments

Comments
 (0)