@@ -226,23 +226,31 @@ def test_chat_completion_with_timings_per_token():
226226}
227227
228228
229- @pytest .mark .parametrize ("template_name,n_predict,tool,argument_key" , [
230- ("meetkai-functionary-medium-v3.1" , 128 , TEST_TOOL , "success" ),
231- ("meetkai-functionary-medium-v3.1" , 128 , PYTHON_TOOL , "code" ),
232- ("meetkai-functionary-medium-v3.2" , 128 , TEST_TOOL , "success" ),
233- ("meetkai-functionary-medium-v3.2" , 128 , PYTHON_TOOL , "code" ),
234- ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use" , 128 , TEST_TOOL , "success" ),
235- ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use" , 128 , PYTHON_TOOL , "code" ),
236- ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use" , 128 , TEST_TOOL , "success" ),
237- ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use" , 128 , PYTHON_TOOL , "code" ),
238- ("meta-llama-Meta-Llama-3.1-8B-Instruct" , 128 , TEST_TOOL , "success" ),
239- ("meta-llama-Meta-Llama-3.1-8B-Instruct" , 128 , PYTHON_TOOL , "code" ),
240- ("meta-llama-Llama-3.2-3B-Instruct" , 128 , TEST_TOOL , "success" ),
241- ("meta-llama-Llama-3.2-3B-Instruct" , 128 , PYTHON_TOOL , "code" ),
242- ("mistralai-Mistral-Nemo-Instruct-2407" , 128 , TEST_TOOL , "success" ),
243- ("mistralai-Mistral-Nemo-Instruct-2407" , 128 , PYTHON_TOOL , "code" ),
229+ @pytest .mark .parametrize ("template_name,tool,argument_key" , [
230+ ("meetkai-functionary-medium-v3.1" , TEST_TOOL , "success" ),
231+ ("meetkai-functionary-medium-v3.1" , PYTHON_TOOL , None ),
232+ ("meetkai-functionary-medium-v3.1" , CODE_INTEPRETER_TOOL , None ),
233+ ("meetkai-functionary-medium-v3.2" , TEST_TOOL , "success" ),
234+ ("meetkai-functionary-medium-v3.2" , PYTHON_TOOL , None ),
235+ ("meetkai-functionary-medium-v3.2" , CODE_INTEPRETER_TOOL , None ),
236+ ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use" , TEST_TOOL , "success" ),
237+ ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use" , PYTHON_TOOL , None ),
238+ ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use" , CODE_INTEPRETER_TOOL , None ),
239+ ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use" , TEST_TOOL , "success" ),
240+ ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use" , PYTHON_TOOL , None ),
241+ ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use" , CODE_INTEPRETER_TOOL , None ),
242+ ("meta-llama-Meta-Llama-3.1-8B-Instruct" , TEST_TOOL , "success" ),
243+ ("meta-llama-Meta-Llama-3.1-8B-Instruct" , PYTHON_TOOL , None ),
244+ ("meta-llama-Meta-Llama-3.1-8B-Instruct" , CODE_INTEPRETER_TOOL , None ),
245+ ("meta-llama-Llama-3.2-3B-Instruct" , TEST_TOOL , "success" ),
246+ ("meta-llama-Llama-3.2-3B-Instruct" , PYTHON_TOOL , None ),
247+ # # ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
248+ ("mistralai-Mistral-Nemo-Instruct-2407" , TEST_TOOL , "success" ),
249+ ("mistralai-Mistral-Nemo-Instruct-2407" , PYTHON_TOOL , None ),
250+ ("mistralai-Mistral-Nemo-Instruct-2407" , CODE_INTEPRETER_TOOL , None ),
244251])
245- def test_completion_with_required_tool (template_name : str , n_predict : int , tool : dict , argument_key : str ):
252+ def test_completion_with_required_tool (template_name : str , tool : dict , argument_key : str | None ):
253+ n_predict = 512
246254 global server
247255 # server = ServerPreset.stories15m_moe()
248256 server .jinja = True
@@ -267,9 +275,13 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
267275 tool_calls = choice ["message" ].get ("tool_calls" )
268276 assert tool_calls and len (tool_calls ) == 1 , f'Expected 1 tool call in { choice ["message" ]} '
269277 tool_call = tool_calls [0 ]
270- assert tool ["function" ]["name" ] == tool_call ["function" ]["name" ]
271- actual_arguments = json .loads (tool_call ["function" ]["arguments" ])
272- assert argument_key in actual_arguments , f"tool arguments: { json .dumps (actual_arguments )} , expected: { argument_key } "
278+ expected_function_name = "python" if tool ["type" ] == "code_interpreter" else tool ["function" ]["name" ]
279+ assert expected_function_name == tool_call ["function" ]["name" ]
280+ actual_arguments = tool_call ["function" ]["arguments" ]
281+ assert isinstance (actual_arguments , str )
282+ if argument_key is not None :
283+ actual_arguments = json .loads (actual_arguments )
284+ assert argument_key in actual_arguments , f"tool arguments: { json .dumps (actual_arguments )} , expected: { argument_key } "
273285
274286
275287@pytest .mark .parametrize ("template_name,n_predict,tools,tool_choice" , [
0 commit comments