@@ -221,6 +221,23 @@ def test_chat_completion_with_timings_per_token():
     }
 }
 
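+# Weather tool definition shared by the tests below: a single function with one required string parameter, "location".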
+WEATHER_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA"
+                }
+            },
+            "required": ["location"]
+        }
+    }
+}
 
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
@@ -308,22 +325,76 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
-    (PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-    (PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
+@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
+    (None, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
     # TODO: fix these models
-    (PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    # (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    # (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    # (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    # (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    # (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
 ])
-def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
-    server.n_slots = 2
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = 128
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = hf_file
+    if template_override:
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    server.start(timeout_seconds=15 * 60)
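+    # Ask a weather question; the assertions below expect a single call to the get_current_weather tool.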
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 256,
+        "messages": [
+            {"role": "user", "content": "What is the weather in Istanbul?"},
+        ],
+        "tools": [WEATHER_TOOL],
+        # "temperature": 0.5,
+        # "top_k": 10,
+        # "top_p": 0.9,
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
+    actual_arguments = tool_call["function"]["arguments"]
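+    # Either compare the raw argument string exactly, or parse it and sanity-check the "location" value.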
+    if expected_arguments is not None:
+        assert actual_arguments == expected_arguments
+    else:
+        actual_arguments = json.loads(actual_arguments)
+        assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
+        location = actual_arguments["location"]
+        assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
+        assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
+    ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
+    # TODO: fix these models
+    # (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    # (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+])
+def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
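+    # Same flow as test_weather_tool_call, but the prompt asks for a Python hello world and a single call to PYTHON_TOOL is expected.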
+    global server
+    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = 128
@@ -341,7 +412,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
             {"role": "user", "content": "say hello world with python"},
             # {"role": "user", "content": "Print a hello world message with python"},
         ],
-        "tools": [tool],
+        "tools": [PYTHON_TOOL],
         "temperature": 0.5,
         "top_k": 10,
         "top_p": 0.9,
@@ -351,10 +422,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
-    if tool["type"] == "function":
-        assert tool["function"]["name"] == tool_call["function"]["name"]
-    elif tool["type"] == "code_interpreter":
-        assert re.match('i?python', tool_call["function"]["name"])
+    assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
     if expected_arguments is not None:
         assert actual_arguments == expected_arguments