Skip to content

Commit e434af9

Browse files
Add model_name to ModelResponse (#701)
Co-authored-by: David Montague <[email protected]>
1 parent 331ea8d commit e434af9

28 files changed

+394
-102
lines changed

docs/agents.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ with capture_run_messages() as messages: # (2)!
445445
part_kind='tool-call',
446446
)
447447
],
448+
model_name='function:model_logic',
448449
timestamp=datetime.datetime(...),
449450
kind='response',
450451
),
@@ -469,6 +470,7 @@ with capture_run_messages() as messages: # (2)!
469470
part_kind='tool-call',
470471
)
471472
],
473+
model_name='function:model_logic',
472474
timestamp=datetime.datetime(...),
473475
kind='response',
474476
),

docs/message-history.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ print(result.all_messages())
6262
part_kind='text',
6363
)
6464
],
65+
model_name='function:model_logic',
6566
timestamp=datetime.datetime(...),
6667
kind='response',
6768
),
@@ -135,6 +136,7 @@ async def main():
135136
part_kind='text',
136137
)
137138
],
139+
model_name='function:stream_model_logic',
138140
timestamp=datetime.datetime(...),
139141
kind='response',
140142
),
@@ -191,6 +193,7 @@ print(result2.all_messages())
191193
part_kind='text',
192194
)
193195
],
196+
model_name='function:model_logic',
194197
timestamp=datetime.datetime(...),
195198
kind='response',
196199
),
@@ -211,6 +214,7 @@ print(result2.all_messages())
211214
part_kind='text',
212215
)
213216
],
217+
model_name='function:model_logic',
214218
timestamp=datetime.datetime(...),
215219
kind='response',
216220
),
@@ -265,6 +269,7 @@ print(result2.all_messages())
265269
part_kind='text',
266270
)
267271
],
272+
model_name='function:model_logic',
268273
timestamp=datetime.datetime(...),
269274
kind='response',
270275
),
@@ -285,6 +290,7 @@ print(result2.all_messages())
285290
part_kind='text',
286291
)
287292
],
293+
model_name='function:model_logic',
288294
timestamp=datetime.datetime(...),
289295
kind='response',
290296
),

docs/testing-evals.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ async def test_forecast():
151151
tool_call_id=None,
152152
)
153153
],
154+
model_name='test',
154155
timestamp=IsNow(tz=timezone.utc),
155156
),
156157
ModelRequest(
@@ -169,6 +170,7 @@ async def test_forecast():
169170
content='{"weather_forecast":"Sunny with a chance of rain"}',
170171
)
171172
],
173+
model_name='test',
172174
timestamp=IsNow(tz=timezone.utc),
173175
),
174176
]

docs/tools.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ print(dice_result.all_messages())
9292
part_kind='tool-call',
9393
)
9494
],
95+
model_name='function:model_logic',
9596
timestamp=datetime.datetime(...),
9697
kind='response',
9798
),
@@ -116,6 +117,7 @@ print(dice_result.all_messages())
116117
part_kind='tool-call',
117118
)
118119
],
120+
model_name='function:model_logic',
119121
timestamp=datetime.datetime(...),
120122
kind='response',
121123
),
@@ -138,6 +140,7 @@ print(dice_result.all_messages())
138140
part_kind='text',
139141
)
140142
],
143+
model_name='function:model_logic',
141144
timestamp=datetime.datetime(...),
142145
kind='response',
143146
),

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ class ModelResponse:
252252
parts: list[ModelResponsePart]
253253
"""The parts of the model message."""
254254

255+
model_name: str | None = None
256+
"""The name of the model that generated the response."""
257+
255258
timestamp: datetime = field(default_factory=_now_utc)
256259
"""The timestamp of the response.
257260
@@ -262,14 +265,14 @@ class ModelResponse:
262265
"""Message type identifier, this is available on all parts as a discriminator."""
263266

264267
@classmethod
265-
def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
268+
def from_text(cls, content: str, model_name: str | None = None, timestamp: datetime | None = None) -> Self:
266269
"""Create a `ModelResponse` containing a single `TextPart`."""
267-
return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
270+
return cls([TextPart(content=content)], model_name=model_name, timestamp=timestamp or _now_utc())
268271

269272
@classmethod
270-
def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
273+
def from_tool_call(cls, tool_call: ToolCallPart, model_name: str | None = None) -> Self:
271274
"""Create a `ModelResponse` containing a single `ToolCallPart`."""
272-
return cls([tool_call])
275+
return cls([tool_call], model_name=model_name)
273276

274277

275278
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ async def request_stream(
161161
class StreamedResponse(ABC):
162162
"""Streamed response from an LLM when calling a tool."""
163163

164+
_model_name: str
164165
_usage: Usage = field(default_factory=Usage, init=False)
165166
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
166167
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
@@ -184,7 +185,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
184185

185186
def get(self) -> ModelResponse:
186187
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
187-
return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
188+
return ModelResponse(
189+
parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
190+
)
191+
192+
def model_name(self) -> str:
193+
"""Get the model name of the response."""
194+
return self._model_name
188195

189196
def usage(self) -> Usage:
190197
"""Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class AnthropicAgentModel(AgentModel):
152152
"""Implementation of `AgentModel` for Anthropic models."""
153153

154154
client: AsyncAnthropic
155-
model_name: str
155+
model_name: AnthropicModelName
156156
allow_text_result: bool
157157
tools: list[ToolParam]
158158

@@ -210,8 +210,7 @@ async def _messages_create(
210210
timeout=model_settings.get('timeout', NOT_GIVEN),
211211
)
212212

213-
@staticmethod
214-
def _process_response(response: AnthropicMessage) -> ModelResponse:
213+
def _process_response(self, response: AnthropicMessage) -> ModelResponse:
215214
"""Process a non-streamed response, and prepare a message to return."""
216215
items: list[ModelResponsePart] = []
217216
for item in response.content:
@@ -227,7 +226,7 @@ def _process_response(response: AnthropicMessage) -> ModelResponse:
227226
)
228227
)
229228

230-
return ModelResponse(items)
229+
return ModelResponse(items, model_name=self.model_name)
231230

232231
@staticmethod
233232
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ async def _chat(
172172
p=model_settings.get('top_p', OMIT),
173173
)
174174

175-
@staticmethod
176-
def _process_response(response: ChatResponse) -> ModelResponse:
175+
def _process_response(self, response: ChatResponse) -> ModelResponse:
177176
"""Process a non-streamed response, and prepare a message to return."""
178177
parts: list[ModelResponsePart] = []
179178
if response.message.content is not None and len(response.message.content) > 0:
@@ -190,7 +189,7 @@ def _process_response(response: ChatResponse) -> ModelResponse:
190189
tool_call_id=c.id,
191190
)
192191
)
193-
return ModelResponse(parts=parts)
192+
return ModelResponse(parts=parts, model_name=self.model_name)
194193

195194
@classmethod
196195
def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,15 @@ async def agent_model(
7171
result_tools: list[ToolDefinition],
7272
) -> AgentModel:
7373
return FunctionAgentModel(
74-
self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
74+
self.function,
75+
self.stream_function,
76+
AgentInfo(function_tools, allow_text_result, result_tools, None),
7577
)
7678

7779
def name(self) -> str:
78-
labels: list[str] = []
79-
if self.function is not None:
80-
labels.append(self.function.__name__)
81-
if self.stream_function is not None:
82-
labels.append(f'stream-{self.stream_function.__name__}')
83-
return f'function:{",".join(labels)}'
80+
function_name = self.function.__name__ if self.function is not None else ''
81+
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
82+
return f'function:{function_name}:{stream_function_name}'
8483

8584

8685
@dataclass(frozen=True)
@@ -147,12 +146,15 @@ async def request(
147146
agent_info = replace(self.agent_info, model_settings=model_settings)
148147

149148
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
149+
model_name = f'function:{self.function.__name__}'
150+
150151
if inspect.iscoroutinefunction(self.function):
151152
response = await self.function(messages, agent_info)
152153
else:
153154
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
154155
assert isinstance(response_, ModelResponse), response_
155156
response = response_
157+
response.model_name = model_name
156158
# TODO is `messages` right here? Should it just be new messages?
157159
return response, _estimate_usage(chain(messages, [response]))
158160

@@ -163,13 +165,15 @@ async def request_stream(
163165
assert (
164166
self.stream_function is not None
165167
), 'FunctionModel must receive a `stream_function` to support streamed requests'
168+
model_name = f'function:{self.stream_function.__name__}'
169+
166170
response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
167171

168172
first = await response_stream.peek()
169173
if isinstance(first, _utils.Unset):
170174
raise ValueError('Stream function must return at least one item')
171175

172-
yield FunctionStreamedResponse(response_stream)
176+
yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
173177

174178

175179
@dataclass

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,13 @@ async def _make_request(
229229
raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
230230
yield r
231231

232-
@staticmethod
233-
def _process_response(response: _GeminiResponse) -> ModelResponse:
232+
def _process_response(self, response: _GeminiResponse) -> ModelResponse:
234233
if len(response['candidates']) != 1:
235234
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
236235
parts = response['candidates'][0]['content']['parts']
237-
return _process_response_from_parts(parts)
236+
return _process_response_from_parts(parts, model_name=self.model_name)
238237

239-
@staticmethod
240-
async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
238+
async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
241239
"""Process a streamed response, and prepare a streaming response to return."""
242240
aiter_bytes = http_response.aiter_bytes()
243241
start_response: _GeminiResponse | None = None
@@ -258,7 +256,7 @@ async def _process_streamed_response(http_response: HTTPResponse) -> StreamedRes
258256
if start_response is None:
259257
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
260258

261-
return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
259+
return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
262260

263261
@classmethod
264262
def _message_to_gemini_content(
@@ -432,7 +430,9 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
432430
return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
433431

434432

435-
def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
433+
def _process_response_from_parts(
434+
parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
435+
) -> ModelResponse:
436436
items: list[ModelResponsePart] = []
437437
for part in parts:
438438
if 'text' in part:
@@ -448,7 +448,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
448448
raise exceptions.UnexpectedModelBehavior(
449449
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
450450
)
451-
return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
451+
return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
452452

453453

454454
class _GeminiFunctionCall(TypedDict):

0 commit comments

Comments (0)