44from pathlib import Path
55
66import pytest
7- from httpx import AsyncClient
7+ from httpx import AsyncClient , Response
88from pydantic import ValidationError
99from sqlalchemy import select
1010from sqlalchemy .orm import selectinload
1313from api .thread .models .thread import Thread
1414from core .message .message_chunk import (
1515 AddMessageChunk ,
16+ ChunkType ,
17+ ErrorChunk ,
1618 FinalThreadChunk ,
1719 StartThreadChunk ,
1820 StreamEndChunk ,
1921 StreamStartChunk ,
2022 ToolCallChunk ,
2123)
24+ from core .message .message_errors import ErrorCode
2225from core .message .role import Role
2326from db .models .message import Message
2427from e2e .conftest import AuthenticatedClient , DatabaseSession , auth_headers_for_user
3134IS_CI = os .getenv ("CI" , "false" ) == "true"
3235
3336
37+ def _get_dict_lines_from_response (response : Response ):
38+ text_lines = response .text .splitlines ()
39+ lines = [json .loads (line ) for line in text_lines ]
40+
41+ return lines
42+
43+
44+ def _get_lines_without_deltas (response : Response ):
45+ lines = _get_dict_lines_from_response (response )
46+ lines_without_stream = [line for line in lines if line ["type" ] != ChunkType .MODEL_RESPONSE .value ]
47+
48+ return lines_without_stream
49+
50+
3451async def test_calls_user_tools (client : AsyncClient , auth_user : AuthenticatedClient , db_session : DatabaseSession ):
35- tool_name = "get_current_weather"
36- tool_definition = CreateToolDefinition (
37- name = tool_name ,
52+ weather_tool_name = "get_current_weather"
53+ weather_tool_definition = CreateToolDefinition (
54+ name = weather_tool_name ,
3855 description = "Get the current weather in a given location" ,
3956 parameters = ParameterDef (
4057 type = "object" ,
@@ -47,7 +64,21 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
4764 },
4865 ),
4966 )
50- tool_definitions = f"[{ tool_definition .model_dump_json ()} ]"
67+
68+ location_tool_name = "get_user_location"
69+ location_tool_definition = CreateToolDefinition (
70+ name = location_tool_name ,
71+ description = "Get the user's location" ,
72+ parameters = ParameterDef (
73+ type = "object" ,
74+ properties = {
75+ "city" : ParameterDef (type = "string" , description = "The user's city" , default = {"string_value" : "Boston" }),
76+ "state" : ParameterDef (type = "string" , description = "The user's state" , default = {"string_value" : "MA" }),
77+ },
78+ ),
79+ )
80+
81+ tool_definitions = f"[{ weather_tool_definition .model_dump_json ()} , { location_tool_definition .model_dump_json ()} ]"
5182 chat_request = UserChatRequest (
5283 content = "test tool calling" ,
5384 model = "test-model" ,
@@ -60,17 +91,18 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
6091
6192 assert_ok_response (response = response )
6293
63- lines = [ json . loads ( line ) for line in response . text . splitlines ()]
94+ lines = _get_dict_lines_from_response ( response )
6495
65- assert len (lines ) == 6
96+ assert len (lines ) == 7
6697 StreamStartChunk .model_validate (lines [0 ])
6798 starting_thread = StartThreadChunk .model_validate (lines [1 ])
6899 AddMessageChunk .model_validate (lines [2 ])
69- tool_call_chunk = ToolCallChunk .model_validate (lines [3 ])
100+ weather_tool_call_chunk = ToolCallChunk .model_validate (lines [3 ])
101+ location_tool_call_chunk = ToolCallChunk .model_validate (lines [4 ])
70102 finished_thread = FinalThreadChunk .model_validate (lines [- 2 ])
71103 StreamEndChunk .model_validate (lines [- 1 ])
72104
73- assert tool_call_chunk .tool_name == tool_name
105+ assert weather_tool_call_chunk .tool_name == weather_tool_name
74106 assert len (starting_thread .messages ) == 2
75107 assert finished_thread .id == starting_thread .id
76108 assert len (finished_thread .messages ) == 3
@@ -79,8 +111,8 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
79111 assert finished_thread .messages [1 ].role == Role .User
80112 assert finished_thread .messages [2 ].role == Role .Assistant
81113 assert finished_thread .messages [2 ].tool_calls
82- assert len (finished_thread .messages [2 ].tool_calls ) == 1
83- assert finished_thread .messages [2 ].tool_calls [0 ].tool_name == tool_name
114+ assert len (finished_thread .messages [2 ].tool_calls ) == 2
115+ assert finished_thread .messages [2 ].tool_calls [0 ].tool_name == weather_tool_name
84116
85117 async with db_session () as session , session .begin ():
86118 message_query = (
@@ -105,7 +137,7 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
105137 model = "test-model" ,
106138 enable_tool_calling = True ,
107139 parent = finished_thread .messages [2 ].id ,
108- tool_call_id = tool_call_chunk .tool_call_id ,
140+ tool_call_id = weather_tool_call_chunk .tool_call_id ,
109141 ).model_dump (exclude_none = True , exclude_computed_fields = True )
110142
111143 tool_request ["toolDefinitions" ] = tool_definitions
@@ -117,20 +149,54 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
117149 lines = [json .loads (line ) for line in tool_response .text .splitlines ()]
118150
119151 StreamStartChunk .model_validate (lines [0 ])
120- tool_response_chunk = AddMessageChunk .model_validate (lines [1 ])
152+ weather_tool_response_chunk = AddMessageChunk .model_validate (lines [1 ])
121153 # ...streaming response...
122- final_thread_chunk = FinalThreadChunk .model_validate (lines [- 2 ])
154+ final_thread_with_pending_tools_chunk = FinalThreadChunk .model_validate (lines [- 2 ])
123155 StreamEndChunk .model_validate (lines [- 1 ])
124156
125- assert tool_response_chunk .messages [0 ].tool_calls , "There were no tool calls in the tool result response"
126- assert tool_response_chunk .messages [0 ].tool_calls [0 ].tool_call_id == tool_call_chunk .tool_call_id
127- assert tool_response_chunk .messages [0 ].content == "Sunny"
157+ assert weather_tool_response_chunk .messages [0 ].tool_calls , "There were no tool calls in the tool result response"
158+ assert weather_tool_response_chunk .messages [0 ].tool_calls [0 ].tool_call_id == weather_tool_call_chunk .tool_call_id
159+ assert weather_tool_response_chunk .messages [0 ].content == "Sunny"
160+
161+ assert len (final_thread_with_pending_tools_chunk .messages ) == 1
162+ assert final_thread_with_pending_tools_chunk .messages [0 ].role == Role .ToolResponse
163+ assert final_thread_with_pending_tools_chunk .messages [0 ].tool_calls
164+ assert len (final_thread_with_pending_tools_chunk .messages [0 ].tool_calls ) == 1
165+
166+ tool_request = ToolResponseChatRequest (
167+ content = '{"city": "Boston", "state": "MA"}' ,
168+ model = "test-model" ,
169+ enable_tool_calling = True ,
170+ parent = final_thread_with_pending_tools_chunk .messages [- 1 ].id ,
171+ tool_call_id = location_tool_call_chunk .tool_call_id ,
172+ ).model_dump (exclude_none = True , exclude_computed_fields = True )
173+
174+ tool_request ["toolDefinitions" ] = tool_definitions
175+
176+ tool_response = await client .post (CHAT_ENDPOINT , data = tool_request , headers = auth_headers_for_user (auth_user ))
177+
178+ assert_ok_response (response = tool_response )
179+
180+ lines = [json .loads (line ) for line in tool_response .text .splitlines ()]
181+
182+ StreamStartChunk .model_validate (lines [0 ])
183+ location_tool_response_chunk = AddMessageChunk .model_validate (lines [1 ])
184+ # ...streaming response...
185+ final_thread_with_pending_tools_chunk = FinalThreadChunk .model_validate (lines [- 2 ])
186+ StreamEndChunk .model_validate (lines [- 1 ])
128187
129- assert len (final_thread_chunk .messages ) == 2
130- assert final_thread_chunk .messages [0 ].role == Role .ToolResponse
131- assert final_thread_chunk .messages [1 ].role == Role .Assistant
132- assert final_thread_chunk .messages [0 ].tool_calls
133- assert len (final_thread_chunk .messages [0 ].tool_calls ) == 1
188+ assert location_tool_response_chunk .messages [0 ].tool_calls , "There were no tool calls in the tool result response"
189+ assert location_tool_response_chunk .messages [0 ].tool_calls [0 ].tool_call_id == location_tool_call_chunk .tool_call_id
190+ assert location_tool_response_chunk .messages [0 ].content == '{"city": "Boston", "state": "MA"}'
191+
192+ assert len (final_thread_with_pending_tools_chunk .messages ) == 2
193+ assert final_thread_with_pending_tools_chunk .messages [0 ].role == Role .ToolResponse
194+ assert final_thread_with_pending_tools_chunk .messages [0 ].tool_calls
195+ assert len (final_thread_with_pending_tools_chunk .messages [0 ].tool_calls ) == 1
196+ assert final_thread_with_pending_tools_chunk .messages [1 ].role == Role .Assistant
197+ assert final_thread_with_pending_tools_chunk .messages [1 ].content , (
198+ "The final response with all tool calls should have content"
199+ )
134200
135201
136202async def test_calls_mcp_tools (client : AsyncClient , auth_user : AuthenticatedClient , db_session : DatabaseSession ):
@@ -146,7 +212,9 @@ async def test_calls_mcp_tools(client: AsyncClient, auth_user: AuthenticatedClie
146212
147213 assert_ok_response (response = response )
148214
149- lines = [json .loads (line ) for line in response .text .splitlines ()]
215+ lines = _get_lines_without_deltas (response )
216+
217+ assert len (lines ) == 8
150218
151219 StreamStartChunk .model_validate (lines [0 ])
152220 starting_thread = StartThreadChunk .model_validate (lines [1 ])
@@ -190,6 +258,68 @@ async def test_calls_mcp_tools(client: AsyncClient, auth_user: AuthenticatedClie
190258 assert message_in_db .children [0 ].id == finished_thread .messages [2 ].id
191259
192260
261+ async def test_calls_a_failing_tool (client : AsyncClient , anon_user : AuthenticatedClient , db_session : DatabaseSession ):
262+ tool_name = "always_fails"
263+ chat_request = UserChatRequest (
264+ content = "test failing tool calling" ,
265+ model = "test-model" ,
266+ enable_tool_calling = True ,
267+ selected_tools = [tool_name ],
268+ ).model_dump (exclude_none = True , exclude_computed_fields = True )
269+
270+ response = await client .post (CHAT_ENDPOINT , data = chat_request , headers = auth_headers_for_user (anon_user ))
271+
272+ assert_ok_response (response = response )
273+
274+ lines = _get_lines_without_deltas (response )
275+
276+ assert len (lines ) == 7
277+
278+ StreamStartChunk .model_validate (lines [0 ])
279+ starting_thread = StartThreadChunk .model_validate (lines [1 ])
280+ tool_call_chunk = ToolCallChunk .model_validate (lines [3 ])
281+ error_chunk = ErrorChunk .model_validate (lines [4 ])
282+ finished_thread = FinalThreadChunk .model_validate (lines [- 2 ])
283+ StreamEndChunk .model_validate (lines [- 1 ])
284+
285+ assert tool_call_chunk .tool_name == tool_name
286+ assert error_chunk .error_code == ErrorCode .TOOL_CALL_ERROR
287+ assert len (starting_thread .messages ) == 2
288+ assert finished_thread .id == starting_thread .id
289+ assert len (finished_thread .messages ) == 3
290+
291+ assert finished_thread .messages [0 ].role == Role .System
292+ assert finished_thread .messages [1 ].role == Role .User
293+
294+ assistant_message = finished_thread .messages [2 ]
295+ assert assistant_message .role == Role .Assistant
296+
297+ assert assistant_message .tool_calls
298+ assert len (assistant_message .tool_calls ) == 1 , "There were no tool calls on the intended tool call message"
299+ assert assistant_message .tool_calls [0 ].tool_name == tool_name
300+ assert assistant_message .error_code == ErrorCode .TOOL_CALL_ERROR
301+
302+ async with db_session () as session , session .begin ():
303+ message_query = (
304+ select (Message )
305+ .where (Message .id == finished_thread .messages [1 ].id )
306+ .options (
307+ selectinload (Message .children ),
308+ selectinload (Message .parent_ ),
309+ )
310+ )
311+ message_in_db_result = await session .scalars (message_query )
312+ message_in_db = message_in_db_result .one ()
313+
314+ assert message_in_db .parent_ is not None and message_in_db .parent_ .id == finished_thread .messages [0 ].id , ( # noqa: PT018
315+ "User message did not get its parent set correctly in the DB"
316+ )
317+ assert message_in_db .children
318+ assert message_in_db .children [0 ].id == finished_thread .messages [2 ].id
319+
320+ assert message_in_db .children [0 ].error_code == ErrorCode .TOOL_CALL_ERROR
321+
322+
193323async def test_does_not_call_tools (client : AsyncClient , anon_user : AuthenticatedClient ):
194324 tool_name = "get_current_weather"
195325 tool_definition = CreateToolDefinition (
@@ -219,7 +349,7 @@ async def test_does_not_call_tools(client: AsyncClient, anon_user: Authenticated
219349
220350 assert_ok_response (response = response )
221351
222- lines = [ json . loads ( line ) for line in response . text . splitlines ()]
352+ lines = _get_dict_lines_from_response ( response )
223353
224354 for line in lines :
225355 with pytest .raises (ValidationError ):
@@ -243,7 +373,7 @@ async def test_makes_a_thread_with_parent(
243373
244374 assert_ok_response (response = response )
245375
246- lines = [ json . loads ( line ) for line in response . text . splitlines ()]
376+ lines = _get_dict_lines_from_response ( response )
247377
248378 assert len (lines ) == 9
249379 StreamStartChunk .model_validate (lines [0 ])
@@ -373,7 +503,7 @@ async def test_uploads_a_file_to_a_multimodal_model(client: AsyncClient, anon_us
373503
374504 assert_ok_response (response = response )
375505
376- lines = [ json . loads ( line ) for line in response . text . splitlines ()]
506+ lines = _get_dict_lines_from_response ( response )
377507
378508 assert len (lines ) == 9
379509 finished_thread = Thread .model_validate (lines [- 2 ])
0 commit comments