
Commit 71b2cb7

Populate ModelResponse.model_name from responses (#883)
1 parent: 21231bf

10 files changed: +52 −49 lines

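The change is the same across all five provider adapters: _process_response previously stamped each ModelResponse with the locally configured model name, and now copies the name the provider reports in its response payload. This matters because the requested name may be an alias (for example a '-latest' tag) that the provider resolves to a concrete model version. A minimal sketch of the pattern, using illustrative stand-in types rather than the actual pydantic-ai internals:

from dataclasses import dataclass


@dataclass
class ProviderResponse:
    """Stand-in for a provider SDK response object (hypothetical, for illustration)."""

    model: str  # the model the provider actually served, e.g. a dated snapshot
    text: str


def process_response(requested_model: str, response: ProviderResponse) -> dict:
    # Before this commit: model_name=requested_model (echoes the caller's alias).
    # After this commit: model_name=response.model (whatever the provider reports).
    return {'parts': [response.text], 'model_name': response.model}


# A '-latest' alias can resolve server-side to a concrete snapshot,
# and it is that resolved name that now gets recorded.
resp = ProviderResponse(model='claude-3-5-haiku-20241022', text='world')
assert process_response('claude-3-5-haiku-latest', resp)['model_name'] == 'claude-3-5-haiku-20241022'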

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
                 )
             )
 
-        return ModelResponse(items, model_name=self._model_name)
+        return ModelResponse(items, model_name=response.model)
 
     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 2 additions & 1 deletion
@@ -233,7 +233,7 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         else:
             raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self._model_name)
+        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""

@@ -610,6 +610,7 @@ class _GeminiResponse(TypedDict):
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
     usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
+    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
 
 
 class _GeminiCandidates(TypedDict):
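The Gemini adapter differs slightly from the others: the response is parsed into a TypedDict, and the newly added modelVersion field is not guaranteed to be present, so the code falls back to the configured model name via dict.get. A small sketch of the fallback behaviour (the payload dicts below are made up for illustration):

# Hypothetical payload dicts, keyed as in the _GeminiResponse TypedDict above.
configured_model = 'gemini-1.5-flash'

with_version = {'model_version': 'gemini-1.5-flash-002'}
without_version = {}

# When the API reports a concrete version it wins; otherwise the
# locally configured name is used, exactly as before this commit.
assert with_version.get('model_version', configured_model) == 'gemini-1.5-flash-002'
assert without_version.get('model_version', configured_model) == 'gemini-1.5-flash'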

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes
             tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)
 
-        return ModelResponse(parts, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(
         self,

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""

tests/models/test_anthropic.py

Lines changed: 7 additions & 7 deletions
@@ -112,7 +112,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
     return AnthropicMessage(
         id='123',
         content=content,
-        model='claude-3-5-haiku-latest',
+        model='claude-3-5-haiku-123',
         role='assistant',
         stop_reason='end_turn',
         type='message',

@@ -141,13 +141,13 @@ async def test_sync_request_text_response(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='claude-3-5-haiku-latest',
+                model_name='claude-3-5-haiku-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='claude-3-5-haiku-latest',
+                model_name='claude-3-5-haiku-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
         ]

@@ -190,7 +190,7 @@ async def test_request_structured_response(allow_model_requests: None):
                     tool_call_id='123',
                 )
             ],
-            model_name='claude-3-5-haiku-latest',
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -252,7 +252,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='1',
                 )
             ],
-            model_name='claude-3-5-haiku-latest',
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -273,7 +273,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='2',
                 )
             ],
-            model_name='claude-3-5-haiku-latest',
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -288,7 +288,7 @@ async def get_location(loc_name: str) -> str:
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='claude-3-5-haiku-latest',
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
     ]

tests/models/test_gemini.py

Lines changed: 10 additions & 8 deletions
@@ -440,7 +440,7 @@ def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | No
     candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
     if finish_reason:  # pragma: no cover
         candidate['finish_reason'] = finish_reason
-    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123')
 
 
 def example_usage() -> _GeminiUsageMetaData:

@@ -459,7 +459,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
         [
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
-                parts=[TextPart(content='Hello world')], model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc)
+                parts=[TextPart(content='Hello world')],
+                model_name='gemini-1.5-flash-123',
+                timestamp=IsNow(tz=timezone.utc),
             ),
         ]
     )

@@ -472,13 +474,13 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='Hello world')],
-                model_name='gemini-1.5-flash',
+                model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='Hello world')],
-                model_name='gemini-1.5-flash',
+                model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
         ]

@@ -505,7 +507,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
                     args={'response': [1, 2, 123]},
                 )
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -566,7 +568,7 @@ async def get_location(loc_name: str) -> str:
                     args={'loc_name': 'San Fransisco'},
                 )
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -589,7 +591,7 @@ async def get_location(loc_name: str) -> str:
                     args={'loc_name': 'New York'},
                 ),
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(

@@ -604,7 +606,7 @@ async def get_location(loc_name: str) -> str:
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
     ]

tests/models/test_groq.py

Lines changed: 7 additions & 7 deletions
@@ -103,7 +103,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
         id='123',
         choices=[Choice(finish_reason='stop', index=0, message=message)],
         created=1704067200,  # 2024-01-01
-        model='llama-3.3-70b-versatile',
+        model='llama-3.3-70b-versatile-123',
         object='chat.completion',
         usage=usage,
     )

@@ -130,13 +130,13 @@ async def test_request_simple_success(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='llama-3.3-70b-versatile',
+                model_name='llama-3.3-70b-versatile-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='llama-3.3-70b-versatile',
+                model_name='llama-3.3-70b-versatile-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
         ]

@@ -187,7 +187,7 @@ async def test_request_structured_response(allow_model_requests: None):
                     tool_call_id='123',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -273,7 +273,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='1',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -294,7 +294,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='2',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -309,7 +309,7 @@ async def get_location(loc_name: str) -> str:
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]

tests/models/test_mistral.py

Lines changed: 15 additions & 15 deletions
@@ -123,7 +123,7 @@ def completion_message(
         id='123',
         choices=[MistralChatCompletionChoice(finish_reason='stop', index=0, message=message)],
         created=1704067200 if with_created else None,  # 2024-01-01
-        model='mistral-large-latest',
+        model='mistral-large-123',
         object='chat.completion',
         usage=usage or MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
     )

@@ -217,13 +217,13 @@ async def test_multiple_completions(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='mistral-large-latest',
+                model_name='mistral-large-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='hello again')],
-                model_name='mistral-large-latest',
+                model_name='mistral-large-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
         ]

@@ -269,19 +269,19 @@ async def test_three_completions(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='mistral-large-latest',
+                model_name='mistral-large-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='hello again')],
-                model_name='mistral-large-latest',
+                model_name='mistral-large-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='final message')],
-                model_name='mistral-large-latest',
+                model_name='mistral-large-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
         ]

@@ -396,7 +396,7 @@ class CityLocation(BaseModel):
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
            timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -458,7 +458,7 @@ class CityLocation(BaseModel):
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -519,7 +519,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -1104,7 +1104,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -1125,7 +1125,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='2',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -1140,7 +1140,7 @@ async def get_location(loc_name: str) -> str:
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]

@@ -1244,7 +1244,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -1265,7 +1265,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='2',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(

@@ -1286,7 +1286,7 @@ async def get_location(loc_name: str) -> str:
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-latest',
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
