@@ -80,6 +80,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
+    context.use_jinja = False
+    context.chat_template_file = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -159,6 +161,16 @@ def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path


+@step('jinja templates are enabled')
+def step_use_jinja(context):
+    context.use_jinja = True
+
+
+@step('chat template file {file}')
+def step_chat_template_file(context, file):
+    context.chat_template_file = file
+
+
 @step('using slot id {id_slot:d}')
 def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
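Note: the two steps added above only record flags on behave's context object; they are consumed later in start_server_background. A minimal standalone sketch of that binding, assuming behave's context can be modelled by SimpleNamespace and using a hypothetical template path:

from types import SimpleNamespace

# Feature-file lines that would match the new step decorators:
#   And jinja templates are enabled
#   And chat template file tests/chat/templates/example.jinja   (hypothetical path)
context = SimpleNamespace(use_jinja=False, chat_template_file=None)  # defaults from step_server_config

# behave's @step decorator registers and returns the function unchanged,
# so the bindings amount to plain attribute writes:
context.use_jinja = True                                           # step_use_jinja
context.chat_template_file = 'tests/chat/templates/example.jinja'  # step_chat_template_file

assert context.use_jinja is True
assert context.chat_template_file.endswith('.jinja')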
@@ -369,7 +381,7 @@ def step_response_format(context, response_format):
 def step_tools(context, tools):
     context.tools = json.loads(tools)

-@step('tool choice {tool_choice}')
+@step('a tool choice {tool_choice}')
 def step_tool_choice(context, tool_choice):
     context.tool_choice = tool_choice

@@ -490,8 +502,11 @@ async def step_oai_chat_completions(context, api_error):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
-                                            seeds[0] if seeds is not None else seeds,
-                                            context.system_prompt,
+                                            seeds[0] if seeds else None,
+
+                                            context.system_prompt
+                                            if hasattr(context, 'system_prompt') else None,
+
                                             context.base_url,
                                             '/v1/chat',
                                             False,
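Aside from tolerating a missing system_prompt attribute, the rewritten arguments also change behaviour for an empty seeds list: the old expression seeds[0] if seeds is not None else seeds would raise IndexError on []. A standalone illustration of the new guards (Ctx is a stand-in for a behave context):

def pick_seed(seeds):
    # New form: falls back to None for both None and an empty list.
    return seeds[0] if seeds else None

assert pick_seed([1234]) == 1234
assert pick_seed([]) is None
assert pick_seed(None) is None

class Ctx:  # stand-in for a behave context with no system_prompt set
    pass

ctx = Ctx()
system_prompt = ctx.system_prompt if hasattr(ctx, 'system_prompt') else None
assert system_prompt is None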
@@ -631,6 +646,43 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
     assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"


+@step('tool {expected_name} is called with arguments {expected_arguments}')
+@async_run_until_complete
+async def step_tool_called(context, expected_name, expected_arguments):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    expected_name = expected_name if expected_name else None
+    expected_arguments = json.loads(expected_arguments) if expected_arguments else None
+
+    def check(tool_calls):
+        if tool_calls is None:
+            assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
+        else:
+            assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
+            tool_call = tool_calls[0]
+            actual_name = tool_call.name
+            actual_arguments = json.loads(tool_call.arguments)
+            assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
+            assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
+@step('no tool is called')
+@async_run_until_complete
+async def step_no_tool_called(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    def check(tool_calls):
+        assert tool_calls is None
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
 @step('embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
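One subtlety in the check above: tool arguments are compared by serializing both sides with json.dumps, which is sensitive to key order. A standalone illustration, plus a possible hardening with sort_keys=True that is not part of this patch:

import json

expected = json.loads('{"location": "Paris", "unit": "celsius"}')
actual = json.loads('{"unit": "celsius", "location": "Paris"}')

# Equal as dicts, but different once serialized, so the step would fail here:
assert expected == actual
assert json.dumps(expected) != json.dumps(actual)

# sort_keys=True makes the serialized comparison order-insensitive:
assert json.dumps(expected, sort_keys=True) == json.dumps(actual, sort_keys=True)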
@@ -1001,19 +1053,23 @@ async def oai_chat_completions(user_prompt,
     print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
+    assert seed is None or isinstance(seed, int), f'seed: {seed}'
     seed = seed if seed is not None else 42
+
     enable_streaming = enable_streaming if enable_streaming is not None else False
+    messages = []
+    if system_prompt:
+        messages.append({
+            "role": "system",
+            "content": system_prompt,
+        })
+    if user_prompt:
+        messages.append({
+            "role": "user",
+            "content": user_prompt,
+        })
     payload = {
-        "messages": [
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        ],
+        "messages": messages,
         "model": model,
         "max_tokens": n_predict,
         "stream": enable_streaming,
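The refactor above makes both messages optional: previously a None system_prompt was still sent as a message with null content. A standalone sketch mirroring the added construction logic:

def build_messages(system_prompt, user_prompt):
    # Mirrors the patched oai_chat_completions: a message is only
    # included when its prompt is truthy.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if user_prompt:
        messages.append({"role": "user", "content": user_prompt})
    return messages

assert build_messages(None, "Hi") == [{"role": "user", "content": "Hi"}]
assert build_messages("Be terse.", "Hi") == [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
]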
@@ -1115,6 +1171,7 @@ async def oai_chat_completions(user_prompt,
         assert chat_completion.usage is not None
         completion_response = {
             'content': chat_completion.choices[0].message.content,
+            'tool_calls': chat_completion.choices[0].message.tool_calls,
             'timings': {
                 'predicted_n': chat_completion.usage.completion_tokens,
                 'prompt_n': chat_completion.usage.prompt_tokens
@@ -1181,11 +1238,13 @@ async def request_oai_embeddings(input, seed,
     return [e.embedding for e in oai_embeddings.data]


-def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
     content = completion_response['content']
+    tool_calls = completion_response.get('tool_calls')
     n_predicted = completion_response['timings']['predicted_n']
-    assert len(content) > 0, "no token predicted"
+    assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0), "no token predicted"
     if re_content is not None:
+        assert content
         p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
         matches = p.finditer(content)
         last_match = 0
@@ -1201,6 +1260,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
             print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+    if tool_calls_check:
+        tool_calls_check(tool_calls)
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
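For reference, the new tool_calls_check hook receives whatever the response dict stored under 'tool_calls', which may be None. A minimal standalone sketch of the contract, with a hand-built response dict standing in for a live completion (real tool-call objects from the openai client expose attributes rather than dict keys):

response = {
    'content': None,
    'tool_calls': [{'name': 'get_weather', 'arguments': '{"location": "Paris"}'}],
    'timings': {'predicted_n': 7, 'prompt_n': 23},
}

content = response['content']
tool_calls = response.get('tool_calls')

# The relaxed "no token predicted" assertion accepts tool-call-only output:
assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0)

def check(tool_calls):
    assert tool_calls is not None and len(tool_calls) == 1

check(tool_calls)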
@@ -1409,6 +1470,10 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.use_jinja:
+        server_args.append('--jinja')
+    if context.chat_template_file:
+        server_args.extend(['--chat-template-file', context.chat_template_file])
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
     if context.disable_ctx_shift:
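Putting the flags together: when both options are set, the branches above yield the following argv fragment (the flag names --jinja and --chat-template-file are taken from the patch itself; the path is hypothetical):

from types import SimpleNamespace

context = SimpleNamespace(use_jinja=True,
                          chat_template_file='tests/chat/templates/example.jinja')

server_args = []
if context.use_jinja:
    server_args.append('--jinja')
if context.chat_template_file:
    server_args.extend(['--chat-template-file', context.chat_template_file])

assert server_args == ['--jinja',
                       '--chat-template-file', 'tests/chat/templates/example.jinja']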