Commit b5b8b89

Adding run_id to ModelRequest and ModelResponse, propagating it via GraphState to these ModelMessages
1 parent: 5768447

File tree (5 files changed: +40 -0 lines)

  pydantic_ai_slim/pydantic_ai/_agent_graph.py
  pydantic_ai_slim/pydantic_ai/agent/__init__.py
  pydantic_ai_slim/pydantic_ai/messages.py
  tests/test_agent.py
  tests/test_messages.py

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 0 deletions
@@ -92,6 +92,7 @@ class GraphAgentState:
     usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage)
     retries: int = 0
     run_step: int = 0
+    run_id: str | None = None

     def increment_retries(
         self,
@@ -469,6 +470,7 @@ async def _make_request(
     async def _prepare_request(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]:
+        self.request.run_id = self.request.run_id or ctx.state.run_id
         ctx.state.message_history.append(self.request)

         ctx.state.run_step += 1
@@ -510,6 +512,7 @@ def _finish_handling(
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         response: _messages.ModelResponse,
     ) -> CallToolsNode[DepsT, NodeRunEndT]:
+        response.run_id = response.run_id or ctx.state.run_id
         # Update usage
         ctx.state.usage.incr(response.usage)
         if ctx.deps.usage_limits:  # pragma: no branch
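
For reference, a minimal, self-contained sketch (not pydantic-ai code) of the stamping pattern the two added lines use: the run-scoped id is applied only when a message does not already carry one, so history replayed from an earlier run keeps its original run_id.

from dataclasses import dataclass

@dataclass
class FakeMessage:  # illustrative stand-in for ModelRequest / ModelResponse
    run_id: str | None = None

def stamp(message: FakeMessage, state_run_id: str | None) -> FakeMessage:
    # Mirrors `x.run_id = x.run_id or ctx.state.run_id`: keep an existing id, otherwise adopt the state's.
    message.run_id = message.run_id or state_run_id
    return message

assert stamp(FakeMessage(run_id='earlier-run'), 'current-run').run_id == 'earlier-run'
assert stamp(FakeMessage(), 'current-run').run_id == 'current-run'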

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 import dataclasses
 import inspect
 import json
+import uuid
 import warnings
 from asyncio import Lock
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
@@ -572,6 +573,7 @@ async def main():
             usage=usage,
             retries=0,
             run_step=0,
+            run_id=str(uuid.uuid4()),
         )

         # Merge model settings in order of precedence: run > agent > model
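
A usage sketch, assuming this commit is applied: each call to `run`/`run_sync` builds a fresh `GraphAgentState`, so every run gets its own uuid4-based run_id, shared by all messages that run produces.

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())

first = agent.run_sync('Hello')
second = agent.run_sync('Hello again')

# Messages within one run share a single run_id; separate runs get different ids.
first_ids = {m.run_id for m in first.all_messages()}
second_ids = {m.run_id for m in second.all_messages()}
assert len(first_ids) == 1 and len(second_ids) == 1
assert first_ids != second_ids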

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 6 additions & 0 deletions
@@ -947,6 +947,9 @@ class ModelRequest:
     kind: Literal['request'] = 'request'
     """Message type identifier, this is available on all parts as a discriminator."""

+    run_id: str | None = None
+    """A unique identifier to identify the run."""
+
     @classmethod
     def user_text_prompt(cls, user_prompt: str, *, instructions: str | None = None) -> ModelRequest:
         """Create a `ModelRequest` with a single user prompt as text."""
@@ -1188,6 +1191,9 @@ class ModelResponse:
     finish_reason: FinishReason | None = None
     """Reason the model finished generating the response, normalized to OpenTelemetry values."""

+    run_id: str | None = None
+    """A unique identifier to identify the run."""
+
     @property
     def text(self) -> str | None:
         """Get the text in the response."""

tests/test_agent.py

Lines changed: 11 additions & 0 deletions
@@ -2506,6 +2506,17 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
     )


+def test_agent_message_history_includes_run_id() -> None:
+    agent = Agent(TestModel(custom_output_text='testing run_id'))
+
+    result = agent.run_sync('Hello')
+    history = result.all_messages()
+
+    run_ids = [message.run_id for message in history]
+    assert run_ids == snapshot([IsStr(), IsStr()])
+    assert len({*run_ids}) == snapshot(1)
+
+
 def test_unknown_tool():
     def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
         return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

tests/test_messages.py

Lines changed: 18 additions & 0 deletions
@@ -469,6 +469,24 @@ def test_file_part_serialization_roundtrip():
     assert deserialized == messages


+def test_model_messages_type_adapter_preserves_run_id():
+    messages: list[ModelMessage] = [
+        ModelRequest(
+            parts=[UserPromptPart(content='Hi there', timestamp=datetime.now(tz=timezone.utc))],
+            run_id='run-123',
+        ),
+        ModelResponse(
+            parts=[TextPart(content='Hello!')],
+            run_id='run-123',
+        ),
+    ]
+
+    serialized = ModelMessagesTypeAdapter.dump_python(messages, mode='python')
+    deserialized = ModelMessagesTypeAdapter.validate_python(serialized)
+
+    assert [message.run_id for message in deserialized] == snapshot(['run-123', 'run-123'])
+
+
 def test_model_response_convenience_methods():
     response = ModelResponse(parts=[])
     assert response.text == snapshot(None)
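
The practical payoff, sketched below (the grouping helper is illustrative, not part of pydantic-ai): when message_history from one run is fed into the next, run_id records which run produced each message, so a combined history can be split back apart.

from collections import defaultdict

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())

first = agent.run_sync('Hello')
second = agent.run_sync('And again', message_history=first.all_messages())

# Group the combined history by the run that produced each message.
by_run: dict[str | None, list[ModelMessage]] = defaultdict(list)
for message in second.all_messages():
    by_run[message.run_id].append(message)

assert len(by_run) == 2  # messages from two distinct runs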
