
Commit a774093

ochafik committed
tool-call: add server tests for llama 3.1
1 parent 9e366b3 commit a774093

3 files changed: +129 -16 lines changed

common/tool-call.cpp

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             tool_rules.push_back(
                 builder.add_rule(
                     name + "-call",
-                    "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " +
+                    "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                     builder.add_schema(name + "-args", parameters) +
                     " \"}\""));
             if (allow_content) {
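
The one-line fix corrects two escapes in the grammar literal: the opening quote before the tool name was missing, and "parameters" was followed by a comma instead of a colon. A minimal Python sketch of what the concatenated C++ literal expands to (string escapes behave the same in both languages; name = "ipython" is an assumed example value):

    # Sketch: expanding the C++ string literal by hand.
    # "ipython" is an assumed example tool name.
    name = "ipython"
    old = "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" "
    new = "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" "
    print(old)  # "\n{\"name\": ipython\", \"parameters\", "   -> malformed JSON prefix
    print(new)  # "\n{\"name\": \"ipython\", \"parameters\": " -> well-formed JSON prefix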

examples/server/tests/features/steps/steps.py

Lines changed: 80 additions & 15 deletions
@@ -80,6 +80,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
+    context.use_jinja = False
+    context.chat_template_file = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -159,6 +161,16 @@ def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path


+@step('jinja templates are enabled')
+def step_use_jinja(context):
+    context.use_jinja = True
+
+
+@step('chat template file {file}')
+def step_chat_template_file(context, file):
+    context.chat_template_file = file
+
+
 @step('using slot id {id_slot:d}')
 def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
@@ -369,7 +381,7 @@ def step_response_format(context, response_format):
 def step_tools(context, tools):
     context.tools = json.loads(tools)

-@step('tool choice {tool_choice}')
+@step('a tool choice {tool_choice}')
 def step_tool_choice(context, tool_choice):
     context.tool_choice = tool_choice

@@ -490,8 +502,11 @@ async def step_oai_chat_completions(context, api_error):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
-                                            seeds[0] if seeds is not None else seeds,
-                                            context.system_prompt,
+                                            seeds[0] if seeds else None,
+
+                                            context.system_prompt
+                                            if hasattr(context, 'system_prompt') else None,
+
                                             context.base_url,
                                             '/v1/chat',
                                             False,
@@ -631,6 +646,43 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
     assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"


+@step('tool {expected_name} is called with arguments {expected_arguments}')
+@async_run_until_complete
+async def step_tool_called(context, expected_name, expected_arguments):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    expected_name = expected_name if expected_name else None
+    expected_arguments = json.loads(expected_arguments) if expected_arguments else None
+
+    def check(tool_calls):
+        if tool_calls is None:
+            assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
+        else:
+            assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
+            tool_call = tool_calls[0]
+            actual_name = tool_call.name
+            actual_arguments = json.loads(tool_call.arguments)
+            assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
+            assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
+
+@step('no tool is called')
+@async_run_until_complete
+async def step_no_tool_called(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    def check(tool_calls):
+        assert tool_calls is None
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
 @step('embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
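
One subtlety in the check above: arguments are matched by comparing json.dumps output rather than the parsed dicts, so the comparison is sensitive to key order as well as values. A minimal sketch illustrating the distinction:

    import json

    # json.dumps comparison is stricter than dict equality: key order matters.
    a = json.loads('{"code": "x", "lang": "py"}')
    b = json.loads('{"lang": "py", "code": "x"}')
    assert a == b                          # dicts compare equal
    assert json.dumps(a) != json.dumps(b)  # serialized forms differ
    # For the single-key arguments used in these tests the two notions coincide.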
@@ -1001,19 +1053,23 @@ async def oai_chat_completions(user_prompt,
     print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
+    assert isinstance(seed, int), f'seed: {seed}'
     seed = seed if seed is not None else 42
+
     enable_streaming = enable_streaming if enable_streaming is not None else False
+    messages = []
+    if system_prompt:
+        messages.append({
+            "role": "system",
+            "content": system_prompt,
+        })
+    if user_prompt:
+        messages.append({
+            "role": "user",
+            "content": user_prompt,
+        })
     payload = {
-        "messages": [
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        ],
+        "messages": messages,
         "model": model,
         "max_tokens": n_predict,
         "stream": enable_streaming,
@@ -1115,6 +1171,7 @@ async def oai_chat_completions(user_prompt,
         assert chat_completion.usage is not None
         completion_response = {
             'content': chat_completion.choices[0].message.content,
+            'tool_calls': chat_completion.choices[0].message.tool_calls,
             'timings': {
                 'predicted_n': chat_completion.usage.completion_tokens,
                 'prompt_n': chat_completion.usage.prompt_tokens
@@ -1181,11 +1238,13 @@ async def request_oai_embeddings(input, seed,
     return [e.embedding for e in oai_embeddings.data]


-def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
     content = completion_response['content']
+    tool_calls = completion_response.get('tool_calls')
     n_predicted = completion_response['timings']['predicted_n']
-    assert len(content) > 0, "no token predicted"
+    assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0), "no token predicted"
     if re_content is not None:
+        assert content
         p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
         matches = p.finditer(content)
         last_match = 0
@@ -1201,6 +1260,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
             print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+    if tool_calls_check:
+        tool_calls_check(tool_calls)
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
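
The relaxed assertion treats a response as non-empty if it has either text content or tool calls, and the optional tool_calls_check callback lets the new steps plug in their own verification. A condensed sketch of the control flow (not the full function):

    # Condensed sketch of the checks above.
    def assert_response(completion_response, tool_calls_check=None):
        content = completion_response['content']
        tool_calls = completion_response.get('tool_calls')
        # Either plain text or at least one tool call counts as a prediction.
        assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0)
        if tool_calls_check:
            tool_calls_check(tool_calls)  # e.g. the check() closures defined in the steps

    assert_response({'content': None, 'tool_calls': [object()]})  # tool-call-only reply passes
    assert_response({'content': 'hi', 'tool_calls': None})        # text-only reply passes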
@@ -1409,6 +1470,10 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.use_jinja:
+        server_args.append('--jinja')
+    if context.chat_template_file:
+        server_args.extend(['--chat-template-file', context.chat_template_file])
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
     if context.disable_ctx_shift:
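
With both new settings enabled, start_server_background passes the two new flags through to the server command line. A sketch of the resulting argument list (the binary name and template path are assumed for illustration):

    # Assumed binary name and template path, for illustration only.
    server_args = ['./llama-server']
    use_jinja = True
    chat_template_file = 'tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja'

    if use_jinja:
        server_args.append('--jinja')
    if chat_template_file:
        server_args.extend(['--chat-template-file', chat_template_file])

    print(' '.join(server_args))
    # ./llama-server --jinja --chat-template-file tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja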
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+@llama.cpp
+@server
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    And 8192 KV cache size
+    And 32 as batch size
+    And 2 slots
+    And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And jinja templates are enabled
+    And chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Health
+    Then the server is ready
+    And all slots are idle
+
+  Scenario Outline: OAI Compatibility w/ required tool
+    Given a model test
+    And <n> max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools <tools>
+    Given an OAI compatible chat completions request with no api error
+    Then tool <tool_name> is called with arguments <tool_arguments>
+
+    Examples: Prompts
+      | n  | tool_name | tool_arguments      | tool_choice | tools |
+      | 64 | test      | {}                  | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | 16 | ipython   | {"code": "it and "} | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+
+  Scenario: OAI Compatibility w/ no tool
+    Given a model test
+    And 16 max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools []
+    Given an OAI compatible chat completions request with no api error
+    Then no tool is called
+
