Commit 29a1c4a

Ensure ModelResponse fields are set from actual model response when streaming (#2848)
1 parent 89c9bb6 commit 29a1c4a

File tree

16 files changed: +315 −40 lines

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -5,7 +5,7 @@
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
@@ -362,13 +362,13 @@ async def _process_streamed_response(
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover
 
-        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
-        timestamp = datetime.now(tz=timezone.utc)
+        assert isinstance(first_chunk, BetaRawMessageStartEvent)
+
         return AnthropicStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.message.model,
             _response=peekable_response,
-            _timestamp=timestamp,
+            _timestamp=_utils.now_utc(),
             _provider_name=self._provider.name,
         )
```
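The change swaps locally derived metadata for values reported by the stream itself: the model name now comes from the `message_start` event, and the hardcoded timezone import gives way to `_utils.now_utc()`. A minimal sketch of the pattern, with plain dicts standing in for Anthropic's event types; pydantic-ai peeks the first event through a wrapper without consuming it, while this sketch simply reads it first:

```python
# Sketch only: seed response metadata from the first streamed event (like
# Anthropic's message_start) instead of from configured defaults.
import asyncio
from collections.abc import AsyncIterator


async def fake_stream() -> AsyncIterator[dict]:
    yield {'type': 'message_start', 'model': 'claude-example'}  # hypothetical model id
    yield {'type': 'text_delta', 'text': 'hello'}


async def main() -> None:
    stream = fake_stream()
    first = await anext(stream)  # peek the metadata event (Python 3.10+)
    model_name = first['model']  # provider-reported, not configured
    print('model:', model_name)
    async for event in stream:   # remaining events stream as usual
        print('event:', event)


asyncio.run(main())
```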

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 31 additions & 8 deletions

```diff
@@ -22,6 +22,7 @@
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -48,13 +49,15 @@
     from botocore.client import BaseClient
     from botocore.eventstream import EventStream
     from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
+    from mypy_boto3_bedrock_runtime.literals import StopReasonType
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
         ConverseRequestTypeDef,
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
+        ConverseStreamResponseTypeDef,
         DocumentBlockTypeDef,
         GuardrailConfigurationTypeDef,
         ImageBlockTypeDef,
@@ -135,6 +138,15 @@
 P = ParamSpec('P')
 T = typing.TypeVar('T')
 
+_FINISH_REASON_MAP: dict[StopReasonType, FinishReason] = {
+    'content_filtered': 'content_filter',
+    'end_turn': 'stop',
+    'guardrail_intervened': 'content_filter',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+}
+
 
 class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
@@ -270,8 +282,9 @@ async def request_stream(
         yield BedrockStreamedResponse(
             model_request_parameters=model_request_parameters,
             _model_name=self.model_name,
-            _event_stream=response,
+            _event_stream=response['stream'],
             _provider_name=self._provider.name,
+            _provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
         )
 
     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -301,12 +314,18 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
             output_tokens=response['usage']['outputTokens'],
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        raw_finish_reason = response['stopReason']
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=u,
             model_name=self.model_name,
             provider_response_id=response_id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     @overload
@@ -316,7 +335,7 @@ async def _messages_create(
         stream: Literal[True],
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseStreamResponseTypeDef:
         pass
 
     @overload
@@ -335,7 +354,7 @@ async def _messages_create(
         stream: bool,
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
 
@@ -372,7 +391,6 @@ async def _messages_create(
         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
-            model_response = model_response['stream']
         else:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response
@@ -599,25 +617,30 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
+    _provider_response_id: str | None = None
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
         """
+        if self._provider_response_id is not None:  # pragma: no cover
+            self.provider_response_id = self._provider_response_id
+
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
             match chunk:
                 case {'messageStart': _}:
                     continue
-                case {'messageStop': _}:
-                    continue
+                case {'messageStop': message_stop}:
+                    raw_finish_reason = message_stop['stopReason']
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                 case {'metadata': metadata}:
                     if 'usage' in metadata:  # pragma: no branch
                         self._usage += self._map_usage(metadata)
-                    continue
                 case {'contentBlockStart': content_block_start}:
                     index = content_block_start['contentBlockIndex']
                     start = content_block_start['start']
```
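The new `_FINISH_REASON_MAP` normalizes Bedrock's stop reasons onto pydantic-ai's `FinishReason` vocabulary, while `provider_details` keeps the raw value. A small self-contained sketch of those semantics, with plain strings standing in for the typed literals:

```python
# Sketch of the mapping behavior: dict.get returns None for any stop reason
# Bedrock may add in the future, while provider_details preserves the raw value.
_FINISH_REASON_MAP = {
    'content_filtered': 'content_filter',
    'end_turn': 'stop',
    'guardrail_intervened': 'content_filter',
    'max_tokens': 'length',
    'stop_sequence': 'stop',
    'tool_use': 'tool_call',
}

for raw in ('end_turn', 'guardrail_intervened', 'some_future_reason'):
    provider_details = {'finish_reason': raw}    # raw value always kept
    finish_reason = _FINISH_REASON_MAP.get(raw)  # None when unknown
    print(f'{raw!r} -> {finish_reason!r}, details={provider_details}')
```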

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 21 additions & 1 deletion

```diff
@@ -14,6 +14,7 @@
 from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -36,6 +37,7 @@
     from cohere import (
         AssistantChatMessageV2,
         AsyncClientV2,
+        ChatFinishReason,
         ChatMessageV2,
         SystemChatMessageV2,
         TextAssistantMessageV2ContentItem,
@@ -80,6 +82,14 @@
 See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
 """
 
+_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
+    'COMPLETE': 'stop',
+    'STOP_SEQUENCE': 'stop',
+    'MAX_TOKENS': 'length',
+    'TOOL_CALL': 'tool_call',
+    'ERROR': 'error',
+}
+
 
 class CohereModelSettings(ModelSettings, total=False):
     """Settings used for a Cohere model request."""
@@ -205,8 +215,18 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
                         tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
+
+        raw_finish_reason = response.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
-            parts=parts, usage=_map_usage(response), model_name=self._model_name, provider_name=self._provider.name
+            parts=parts,
+            usage=_map_usage(response),
+            model_name=self._model_name,
+            provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
```
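Same shape as the Bedrock change, with Cohere's upper-case `ChatFinishReason` literals. The payoff of normalizing across providers is that callers can branch on a single vocabulary; a hedged sketch, with mapping tables abbreviated from the diffs above and the handler invented for illustration:

```python
# Once Cohere's 'COMPLETE' and Bedrock's 'end_turn' both normalize to 'stop',
# downstream handling needs only one branch per normalized reason.
COHERE = {'COMPLETE': 'stop', 'MAX_TOKENS': 'length', 'TOOL_CALL': 'tool_call', 'ERROR': 'error'}
BEDROCK = {'end_turn': 'stop', 'max_tokens': 'length', 'tool_use': 'tool_call'}


def handle(finish_reason: str | None) -> str:
    # Hypothetical caller logic, provider-agnostic after normalization.
    return {
        'stop': 'done',
        'length': 'output truncated',
        'tool_call': 'run the requested tools',
        'error': 'retry or surface the failure',
    }.get(finish_reason or '', 'unknown finish reason')


print(handle(COHERE.get('COMPLETE')))     # done
print(handle(BEDROCK.get('max_tokens')))  # output truncated
```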

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -453,7 +453,7 @@ async def _process_streamed_response(
 
         return GeminiStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model_version or self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
             _provider_name=self._provider.name,
```
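A one-line change, but it follows the same principle as the rest of the commit: prefer what the response reports (`model_version`, `create_time`) and fall back to the local value only when the provider omits it. A tiny sketch of the fallback, with chunk field values assumed:

```python
from datetime import datetime, timezone

configured_model = 'gemini-2.0-flash'  # hypothetical: what the caller asked for
chunk_model_version = None             # Gemini may omit model_version on a chunk
chunk_create_time = None               # ...and create_time

model_name = chunk_model_version or configured_model       # falls back
timestamp = chunk_create_time or datetime.now(tz=timezone.utc)
print(model_name, timestamp)
```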

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 30 additions & 2 deletions

```diff
@@ -23,6 +23,7 @@
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -100,6 +101,14 @@
 See <https://console.groq.com/docs/models> for an up to date list of models and more details.
 """
 
+_FINISH_REASON_MAP: dict[Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'tool_calls': 'tool_call',
+    'content_filter': 'content_filter',
+    'function_call': 'tool_call',
+}
+
 
 class GroqModelSettings(ModelSettings, total=False):
     """Settings used for a Groq model request."""
@@ -186,7 +195,13 @@ async def request(
                 tool_name=error.error.failed_generation.name,
                 args=error.error.failed_generation.arguments,
             )
-            return ModelResponse(parts=[tool_call_part])
+            return ModelResponse(
+                parts=[tool_call_part],
+                model_name=e.model_name,
+                timestamp=_utils.now_utc(),
+                provider_name=self._provider.name,
+                finish_reason='error',
+            )
         except ValidationError:
             pass
         raise
@@ -315,13 +330,19 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -338,7 +359,7 @@ async def _process_streamed_response(
         return GroqStreamedResponse(
             model_request_parameters=model_request_parameters,
             _response=peekable_response,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _timestamp=number_to_datetime(first_chunk.created),
             _provider_name=self._provider.name,
@@ -497,11 +518,18 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
             self._usage += _map_usage(chunk)
 
+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue
 
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
```
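The streaming iterator now records the response id from each chunk and the finish reason from the final chunk as a side effect of iteration. A condensed sketch of that loop, with plain dicts standing in for Groq's chunk objects:

```python
# Sketch only: capture the response id from any chunk that carries one, and
# record both the raw and the normalized finish reason from the final chunk.
_MAP = {'stop': 'stop', 'length': 'length', 'tool_calls': 'tool_call'}

provider_response_id = None
finish_reason = None
provider_details = None

chunks = [
    {'id': 'chatcmpl-123', 'finish_reason': None, 'delta': 'Hel'},
    {'id': 'chatcmpl-123', 'finish_reason': 'stop', 'delta': 'lo'},
]
for chunk in chunks:
    if chunk['id']:
        provider_response_id = chunk['id']
    if raw := chunk['finish_reason']:
        provider_details = {'finish_reason': raw}
        finish_reason = _MAP.get(raw)

print(provider_response_id, finish_reason, provider_details)
```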

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 25 additions & 1 deletion

```diff
@@ -20,6 +20,7 @@
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -58,6 +59,7 @@
         ChatCompletionOutput,
         ChatCompletionOutputMessage,
         ChatCompletionStreamOutput,
+        TextGenerationOutputFinishReason,
     )
     from huggingface_hub.errors import HfHubHTTPError
 
@@ -94,6 +96,12 @@
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
 """
 
+_FINISH_REASON_MAP: dict[TextGenerationOutputFinishReason, FinishReason] = {
+    'length': 'length',
+    'eos_token': 'stop',
+    'stop_sequence': 'stop',
+}
+
 
 class HuggingFaceModelSettings(ModelSettings, total=False):
     """Settings used for a Hugging Face model request."""
@@ -266,13 +274,20 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
         if tool_calls is not None:
             for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(cast(TextGenerationOutputFinishReason, raw_finish_reason), None)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -288,7 +303,7 @@ async def _process_streamed_response(
 
         return HuggingFaceStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
@@ -445,11 +460,20 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
             self._usage += _map_usage(chunk)
 
+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue
 
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(
+                    cast(TextGenerationOutputFinishReason, raw_finish_reason), None
+                )
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
```
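Unlike Groq, `huggingface_hub` types `finish_reason` as a plain `str`, hence the `cast` plus a defaulted `dict.get`: the cast satisfies the type checker, and any value outside the known literals simply yields `None`. A sketch of that defensive lookup:

```python
# Sketch only: cast narrows the type for the checker; at runtime it is a no-op,
# and dict.get handles values outside the known literal set.
from typing import Literal, cast

TextGenerationOutputFinishReason = Literal['length', 'eos_token', 'stop_sequence']
_MAP = {'length': 'length', 'eos_token': 'stop', 'stop_sequence': 'stop'}

raw: str = 'eos_token'  # arrives off the wire as an untyped string
finish = _MAP.get(cast(TextGenerationOutputFinishReason, raw), None)
print(finish)  # 'stop'; an unexpected raw value would print None
```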
