Skip to content

Commit 6cd7303

Browse files
authored
v5 chat - Save thread errors (#559)
Closes allenai/playground-issues-repo#1028 Now that we've released this it's time for me to regret the decision to use P-AI's UI event streaming handling! I had to hack around it to get error saving working. P-AI doesn't return errors in the run output handling or if you use something like [`capture_run_messages`](https://ai.pydantic.dev/agent/#model-errors). To get it, I'm gathering relevant chunks while streaming by yielding through `stream_chat_message`. Since that could cause issues with throwing exceptions properly I split `stream_chat_message` into `initialize_stream_adapter` and `stream_chat_message`. Since those chunks don't have the full response we still need to map from the Pydantic chunks. Since we have the errors as well now, we can map those over by their message ID. Image of me loading a thread that had an error: <img width="819" height="389" alt="image" src="https://github.com/user-attachments/assets/cbb8d468-fac6-4f91-94b4-55cf7071e189" /> Loading a thread that had a failing tool call: <img width="753" height="282" alt="image" src="https://github.com/user-attachments/assets/45ab4257-707b-444b-81ef-7d885bce4e9d" />
1 parent adfe04c commit 6cd7303

File tree

16 files changed

+543
-338
lines changed

16 files changed

+543
-338
lines changed

.env.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@ RECAPTCHA_KEY=fake
2323
SAFETY_QUEUE_ENABLED=false
2424
SAFETY_QUEUE_URL=fake
2525
SAFTEY_GCS_UPLOAD_BUCKET=fake
26+
27+
INCLUDE_TEST_MCP_SERVERS=true

.github/actions/set-up-uv/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ runs:
77
- name: Set up Python
88
uses: actions/setup-python@v6
99
with:
10-
python-version: "3.11"
10+
python-version: "3.14"
1111

1212
- name: Install uv
1313
uses: astral-sh/setup-uv@v7

.vscode/launch.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@
9393
"PYTHONPATH": "${workspaceFolder}",
9494
"ENV": "development"
9595
}
96+
},
97+
{
98+
"name": "Python: Debug Tests",
99+
"type": "debugpy",
100+
"request": "launch",
101+
"program": "${file}",
102+
"purpose": ["debug-test"],
103+
"console": "integratedTerminal",
104+
"justMyCode": false
96105
}
106+
97107
]
98108
}

apps/api/e2e/test_chat.py

Lines changed: 156 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7-
from httpx import AsyncClient
7+
from httpx import AsyncClient, Response
88
from pydantic import ValidationError
99
from sqlalchemy import select
1010
from sqlalchemy.orm import selectinload
@@ -13,12 +13,15 @@
1313
from api.thread.models.thread import Thread
1414
from core.message.message_chunk import (
1515
AddMessageChunk,
16+
ChunkType,
17+
ErrorChunk,
1618
FinalThreadChunk,
1719
StartThreadChunk,
1820
StreamEndChunk,
1921
StreamStartChunk,
2022
ToolCallChunk,
2123
)
24+
from core.message.message_errors import ErrorCode
2225
from core.message.role import Role
2326
from db.models.message import Message
2427
from e2e.conftest import AuthenticatedClient, DatabaseSession, auth_headers_for_user
@@ -31,10 +34,24 @@
3134
IS_CI = os.getenv("CI", "false") == "true"
3235

3336

37+
def _get_dict_lines_from_response(response: Response):
38+
text_lines = response.text.splitlines()
39+
lines = [json.loads(line) for line in text_lines]
40+
41+
return lines
42+
43+
44+
def _get_lines_without_deltas(response: Response):
45+
lines = _get_dict_lines_from_response(response)
46+
lines_without_stream = [line for line in lines if line["type"] != ChunkType.MODEL_RESPONSE.value]
47+
48+
return lines_without_stream
49+
50+
3451
async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedClient, db_session: DatabaseSession):
35-
tool_name = "get_current_weather"
36-
tool_definition = CreateToolDefinition(
37-
name=tool_name,
52+
weather_tool_name = "get_current_weather"
53+
weather_tool_definition = CreateToolDefinition(
54+
name=weather_tool_name,
3855
description="Get the current weather in a given location",
3956
parameters=ParameterDef(
4057
type="object",
@@ -47,7 +64,21 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
4764
},
4865
),
4966
)
50-
tool_definitions = f"[{tool_definition.model_dump_json()}]"
67+
68+
location_tool_name = "get_user_location"
69+
location_tool_definition = CreateToolDefinition(
70+
name=location_tool_name,
71+
description="Get the user's location",
72+
parameters=ParameterDef(
73+
type="object",
74+
properties={
75+
"city": ParameterDef(type="string", description="The user's city", default={"string_value": "Boston"}),
76+
"state": ParameterDef(type="string", description="The user's state", default={"string_value": "MA"}),
77+
},
78+
),
79+
)
80+
81+
tool_definitions = f"[{weather_tool_definition.model_dump_json()}, {location_tool_definition.model_dump_json()}]"
5182
chat_request = UserChatRequest(
5283
content="test tool calling",
5384
model="test-model",
@@ -60,17 +91,18 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
6091

6192
assert_ok_response(response=response)
6293

63-
lines = [json.loads(line) for line in response.text.splitlines()]
94+
lines = _get_dict_lines_from_response(response)
6495

65-
assert len(lines) == 6
96+
assert len(lines) == 7
6697
StreamStartChunk.model_validate(lines[0])
6798
starting_thread = StartThreadChunk.model_validate(lines[1])
6899
AddMessageChunk.model_validate(lines[2])
69-
tool_call_chunk = ToolCallChunk.model_validate(lines[3])
100+
weather_tool_call_chunk = ToolCallChunk.model_validate(lines[3])
101+
location_tool_call_chunk = ToolCallChunk.model_validate(lines[4])
70102
finished_thread = FinalThreadChunk.model_validate(lines[-2])
71103
StreamEndChunk.model_validate(lines[-1])
72104

73-
assert tool_call_chunk.tool_name == tool_name
105+
assert weather_tool_call_chunk.tool_name == weather_tool_name
74106
assert len(starting_thread.messages) == 2
75107
assert finished_thread.id == starting_thread.id
76108
assert len(finished_thread.messages) == 3
@@ -79,8 +111,8 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
79111
assert finished_thread.messages[1].role == Role.User
80112
assert finished_thread.messages[2].role == Role.Assistant
81113
assert finished_thread.messages[2].tool_calls
82-
assert len(finished_thread.messages[2].tool_calls) == 1
83-
assert finished_thread.messages[2].tool_calls[0].tool_name == tool_name
114+
assert len(finished_thread.messages[2].tool_calls) == 2
115+
assert finished_thread.messages[2].tool_calls[0].tool_name == weather_tool_name
84116

85117
async with db_session() as session, session.begin():
86118
message_query = (
@@ -105,7 +137,7 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
105137
model="test-model",
106138
enable_tool_calling=True,
107139
parent=finished_thread.messages[2].id,
108-
tool_call_id=tool_call_chunk.tool_call_id,
140+
tool_call_id=weather_tool_call_chunk.tool_call_id,
109141
).model_dump(exclude_none=True, exclude_computed_fields=True)
110142

111143
tool_request["toolDefinitions"] = tool_definitions
@@ -117,20 +149,54 @@ async def test_calls_user_tools(client: AsyncClient, auth_user: AuthenticatedCli
117149
lines = [json.loads(line) for line in tool_response.text.splitlines()]
118150

119151
StreamStartChunk.model_validate(lines[0])
120-
tool_response_chunk = AddMessageChunk.model_validate(lines[1])
152+
weather_tool_response_chunk = AddMessageChunk.model_validate(lines[1])
121153
# ...streaming response...
122-
final_thread_chunk = FinalThreadChunk.model_validate(lines[-2])
154+
final_thread_with_pending_tools_chunk = FinalThreadChunk.model_validate(lines[-2])
123155
StreamEndChunk.model_validate(lines[-1])
124156

125-
assert tool_response_chunk.messages[0].tool_calls, "There were no tool calls in the tool result response"
126-
assert tool_response_chunk.messages[0].tool_calls[0].tool_call_id == tool_call_chunk.tool_call_id
127-
assert tool_response_chunk.messages[0].content == "Sunny"
157+
assert weather_tool_response_chunk.messages[0].tool_calls, "There were no tool calls in the tool result response"
158+
assert weather_tool_response_chunk.messages[0].tool_calls[0].tool_call_id == weather_tool_call_chunk.tool_call_id
159+
assert weather_tool_response_chunk.messages[0].content == "Sunny"
160+
161+
assert len(final_thread_with_pending_tools_chunk.messages) == 1
162+
assert final_thread_with_pending_tools_chunk.messages[0].role == Role.ToolResponse
163+
assert final_thread_with_pending_tools_chunk.messages[0].tool_calls
164+
assert len(final_thread_with_pending_tools_chunk.messages[0].tool_calls) == 1
165+
166+
tool_request = ToolResponseChatRequest(
167+
content='{"city": "Boston", "state": "MA"}',
168+
model="test-model",
169+
enable_tool_calling=True,
170+
parent=final_thread_with_pending_tools_chunk.messages[-1].id,
171+
tool_call_id=location_tool_call_chunk.tool_call_id,
172+
).model_dump(exclude_none=True, exclude_computed_fields=True)
173+
174+
tool_request["toolDefinitions"] = tool_definitions
175+
176+
tool_response = await client.post(CHAT_ENDPOINT, data=tool_request, headers=auth_headers_for_user(auth_user))
177+
178+
assert_ok_response(response=tool_response)
179+
180+
lines = [json.loads(line) for line in tool_response.text.splitlines()]
181+
182+
StreamStartChunk.model_validate(lines[0])
183+
location_tool_response_chunk = AddMessageChunk.model_validate(lines[1])
184+
# ...streaming response...
185+
final_thread_with_pending_tools_chunk = FinalThreadChunk.model_validate(lines[-2])
186+
StreamEndChunk.model_validate(lines[-1])
128187

129-
assert len(final_thread_chunk.messages) == 2
130-
assert final_thread_chunk.messages[0].role == Role.ToolResponse
131-
assert final_thread_chunk.messages[1].role == Role.Assistant
132-
assert final_thread_chunk.messages[0].tool_calls
133-
assert len(final_thread_chunk.messages[0].tool_calls) == 1
188+
assert location_tool_response_chunk.messages[0].tool_calls, "There were no tool calls in the tool result response"
189+
assert location_tool_response_chunk.messages[0].tool_calls[0].tool_call_id == location_tool_call_chunk.tool_call_id
190+
assert location_tool_response_chunk.messages[0].content == '{"city": "Boston", "state": "MA"}'
191+
192+
assert len(final_thread_with_pending_tools_chunk.messages) == 2
193+
assert final_thread_with_pending_tools_chunk.messages[0].role == Role.ToolResponse
194+
assert final_thread_with_pending_tools_chunk.messages[0].tool_calls
195+
assert len(final_thread_with_pending_tools_chunk.messages[0].tool_calls) == 1
196+
assert final_thread_with_pending_tools_chunk.messages[1].role == Role.Assistant
197+
assert final_thread_with_pending_tools_chunk.messages[1].content, (
198+
"The final response with all tool calls should have content"
199+
)
134200

135201

136202
async def test_calls_mcp_tools(client: AsyncClient, auth_user: AuthenticatedClient, db_session: DatabaseSession):
@@ -146,7 +212,9 @@ async def test_calls_mcp_tools(client: AsyncClient, auth_user: AuthenticatedClie
146212

147213
assert_ok_response(response=response)
148214

149-
lines = [json.loads(line) for line in response.text.splitlines()]
215+
lines = _get_lines_without_deltas(response)
216+
217+
assert len(lines) == 8
150218

151219
StreamStartChunk.model_validate(lines[0])
152220
starting_thread = StartThreadChunk.model_validate(lines[1])
@@ -190,6 +258,68 @@ async def test_calls_mcp_tools(client: AsyncClient, auth_user: AuthenticatedClie
190258
assert message_in_db.children[0].id == finished_thread.messages[2].id
191259

192260

261+
async def test_calls_a_failing_tool(client: AsyncClient, anon_user: AuthenticatedClient, db_session: DatabaseSession):
262+
tool_name = "always_fails"
263+
chat_request = UserChatRequest(
264+
content="test failing tool calling",
265+
model="test-model",
266+
enable_tool_calling=True,
267+
selected_tools=[tool_name],
268+
).model_dump(exclude_none=True, exclude_computed_fields=True)
269+
270+
response = await client.post(CHAT_ENDPOINT, data=chat_request, headers=auth_headers_for_user(anon_user))
271+
272+
assert_ok_response(response=response)
273+
274+
lines = _get_lines_without_deltas(response)
275+
276+
assert len(lines) == 7
277+
278+
StreamStartChunk.model_validate(lines[0])
279+
starting_thread = StartThreadChunk.model_validate(lines[1])
280+
tool_call_chunk = ToolCallChunk.model_validate(lines[3])
281+
error_chunk = ErrorChunk.model_validate(lines[4])
282+
finished_thread = FinalThreadChunk.model_validate(lines[-2])
283+
StreamEndChunk.model_validate(lines[-1])
284+
285+
assert tool_call_chunk.tool_name == tool_name
286+
assert error_chunk.error_code == ErrorCode.TOOL_CALL_ERROR
287+
assert len(starting_thread.messages) == 2
288+
assert finished_thread.id == starting_thread.id
289+
assert len(finished_thread.messages) == 3
290+
291+
assert finished_thread.messages[0].role == Role.System
292+
assert finished_thread.messages[1].role == Role.User
293+
294+
assistant_message = finished_thread.messages[2]
295+
assert assistant_message.role == Role.Assistant
296+
297+
assert assistant_message.tool_calls
298+
assert len(assistant_message.tool_calls) == 1, "There were no tool calls on the intended tool call message"
299+
assert assistant_message.tool_calls[0].tool_name == tool_name
300+
assert assistant_message.error_code == ErrorCode.TOOL_CALL_ERROR
301+
302+
async with db_session() as session, session.begin():
303+
message_query = (
304+
select(Message)
305+
.where(Message.id == finished_thread.messages[1].id)
306+
.options(
307+
selectinload(Message.children),
308+
selectinload(Message.parent_),
309+
)
310+
)
311+
message_in_db_result = await session.scalars(message_query)
312+
message_in_db = message_in_db_result.one()
313+
314+
assert message_in_db.parent_ is not None and message_in_db.parent_.id == finished_thread.messages[0].id, ( # noqa: PT018
315+
"User message did not get its parent set correctly in the DB"
316+
)
317+
assert message_in_db.children
318+
assert message_in_db.children[0].id == finished_thread.messages[2].id
319+
320+
assert message_in_db.children[0].error_code == ErrorCode.TOOL_CALL_ERROR
321+
322+
193323
async def test_does_not_call_tools(client: AsyncClient, anon_user: AuthenticatedClient):
194324
tool_name = "get_current_weather"
195325
tool_definition = CreateToolDefinition(
@@ -219,7 +349,7 @@ async def test_does_not_call_tools(client: AsyncClient, anon_user: Authenticated
219349

220350
assert_ok_response(response=response)
221351

222-
lines = [json.loads(line) for line in response.text.splitlines()]
352+
lines = _get_dict_lines_from_response(response)
223353

224354
for line in lines:
225355
with pytest.raises(ValidationError):
@@ -243,7 +373,7 @@ async def test_makes_a_thread_with_parent(
243373

244374
assert_ok_response(response=response)
245375

246-
lines = [json.loads(line) for line in response.text.splitlines()]
376+
lines = _get_dict_lines_from_response(response)
247377

248378
assert len(lines) == 9
249379
StreamStartChunk.model_validate(lines[0])
@@ -373,7 +503,7 @@ async def test_uploads_a_file_to_a_multimodal_model(client: AsyncClient, anon_us
373503

374504
assert_ok_response(response=response)
375505

376-
lines = [json.loads(line) for line in response.text.splitlines()]
506+
lines = _get_dict_lines_from_response(response)
377507

378508
assert len(lines) == 9
379509
finished_thread = Thread.model_validate(lines[-2])

apps/api/src/api/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ class Settings(BaseSettings):
7676
SAFETY_QUEUE_URL: str = Field(init=False)
7777
SAFTEY_GCS_UPLOAD_BUCKET: str = Field(init=False)
7878

79+
INCLUDE_TEST_MCP_SERVERS: bool = Field(
80+
default=False,
81+
description="Used to enable/disable the fake MCP server in test_utils/fake_mcp_server",
82+
)
83+
7984
model_config = SettingsConfigDict(
8085
extra="ignore",
8186
env_file=(".env", f".env.{environment}", ".env.local", f".env.{environment}.local"),

apps/api/src/api/test_utils/fake_mcp_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
from pydantic_ai import FunctionToolset
1+
from typing import NoReturn
22

3-
test_toolset = FunctionToolset()
3+
from pydantic_ai import FunctionToolset, ModelRetry, RunContext
4+
5+
test_toolset = FunctionToolset(max_retries=0)
6+
7+
8+
@test_toolset.tool()
9+
async def always_fails(ctx: RunContext) -> NoReturn: # noqa: ARG001, RUF029
10+
raise ModelRetry("Always fails") # noqa: EM101, TRY003
411

512

613
@test_toolset.tool()
@@ -13,6 +20,7 @@ async def celsius_to_fahrenheit(celsius: float) -> float: # noqa: RUF029
1320
Returns:
1421
Temperature in Fahrenheit
1522
"""
23+
1624
return (celsius * 9 / 5) + 32
1725

1826

0 commit comments

Comments
 (0)