22
22
BuiltinToolCallPart ,
23
23
BuiltinToolReturnPart ,
24
24
DocumentUrl ,
25
+ FinishReason ,
25
26
ImageUrl ,
26
27
ModelMessage ,
27
28
ModelRequest ,
48
49
from botocore .client import BaseClient
49
50
from botocore .eventstream import EventStream
50
51
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
52
+ from mypy_boto3_bedrock_runtime .literals import StopReasonType
51
53
from mypy_boto3_bedrock_runtime .type_defs import (
52
54
ContentBlockOutputTypeDef ,
53
55
ContentBlockUnionTypeDef ,
54
56
ConverseRequestTypeDef ,
55
57
ConverseResponseTypeDef ,
56
58
ConverseStreamMetadataEventTypeDef ,
57
59
ConverseStreamOutputTypeDef ,
60
+ ConverseStreamResponseTypeDef ,
58
61
DocumentBlockTypeDef ,
59
62
GuardrailConfigurationTypeDef ,
60
63
ImageBlockTypeDef ,
135
138
P = ParamSpec ('P' )
136
139
T = typing .TypeVar ('T' )
137
140
141
+ _FINISH_REASON_MAP : dict [StopReasonType , FinishReason ] = {
142
+ 'content_filtered' : 'content_filter' ,
143
+ 'end_turn' : 'stop' ,
144
+ 'guardrail_intervened' : 'content_filter' ,
145
+ 'max_tokens' : 'length' ,
146
+ 'stop_sequence' : 'stop' ,
147
+ 'tool_use' : 'tool_call' ,
148
+ }
149
+
138
150
139
151
class BedrockModelSettings (ModelSettings , total = False ):
140
152
"""Settings for Bedrock models.
@@ -270,8 +282,9 @@ async def request_stream(
270
282
yield BedrockStreamedResponse (
271
283
model_request_parameters = model_request_parameters ,
272
284
_model_name = self .model_name ,
273
- _event_stream = response ,
285
+ _event_stream = response [ 'stream' ] ,
274
286
_provider_name = self ._provider .name ,
287
+ _provider_response_id = response .get ('ResponseMetadata' , {}).get ('RequestId' , None ),
275
288
)
276
289
277
290
async def _process_response (self , response : ConverseResponseTypeDef ) -> ModelResponse :
@@ -301,12 +314,18 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
301
314
output_tokens = response ['usage' ]['outputTokens' ],
302
315
)
303
316
response_id = response .get ('ResponseMetadata' , {}).get ('RequestId' , None )
317
+ raw_finish_reason = response ['stopReason' ]
318
+ provider_details = {'finish_reason' : raw_finish_reason }
319
+ finish_reason = _FINISH_REASON_MAP .get (raw_finish_reason )
320
+
304
321
return ModelResponse (
305
322
parts = items ,
306
323
usage = u ,
307
324
model_name = self .model_name ,
308
325
provider_response_id = response_id ,
309
326
provider_name = self ._provider .name ,
327
+ finish_reason = finish_reason ,
328
+ provider_details = provider_details ,
310
329
)
311
330
312
331
@overload
@@ -316,7 +335,7 @@ async def _messages_create(
316
335
stream : Literal [True ],
317
336
model_settings : BedrockModelSettings | None ,
318
337
model_request_parameters : ModelRequestParameters ,
319
- ) -> EventStream [ ConverseStreamOutputTypeDef ] :
338
+ ) -> ConverseStreamResponseTypeDef :
320
339
pass
321
340
322
341
@overload
@@ -335,7 +354,7 @@ async def _messages_create(
335
354
stream : bool ,
336
355
model_settings : BedrockModelSettings | None ,
337
356
model_request_parameters : ModelRequestParameters ,
338
- ) -> ConverseResponseTypeDef | EventStream [ ConverseStreamOutputTypeDef ] :
357
+ ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef :
339
358
system_prompt , bedrock_messages = await self ._map_messages (messages )
340
359
inference_config = self ._map_inference_config (model_settings )
341
360
@@ -372,7 +391,6 @@ async def _messages_create(
372
391
373
392
if stream :
374
393
model_response = await anyio .to_thread .run_sync (functools .partial (self .client .converse_stream , ** params ))
375
- model_response = model_response ['stream' ]
376
394
else :
377
395
model_response = await anyio .to_thread .run_sync (functools .partial (self .client .converse , ** params ))
378
396
return model_response
@@ -599,25 +617,30 @@ class BedrockStreamedResponse(StreamedResponse):
599
617
_event_stream : EventStream [ConverseStreamOutputTypeDef ]
600
618
_provider_name : str
601
619
_timestamp : datetime = field (default_factory = _utils .now_utc )
620
+ _provider_response_id : str | None = None
602
621
603
- async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]:
622
+ async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]: # noqa: C901
604
623
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
605
624
606
625
This method should be implemented by subclasses to translate the vendor-specific stream of events into
607
626
pydantic_ai-format events.
608
627
"""
628
+ if self ._provider_response_id is not None : # pragma: no cover
629
+ self .provider_response_id = self ._provider_response_id
630
+
609
631
chunk : ConverseStreamOutputTypeDef
610
632
tool_id : str | None = None
611
633
async for chunk in _AsyncIteratorWrapper (self ._event_stream ):
612
634
match chunk :
613
635
case {'messageStart' : _}:
614
636
continue
615
- case {'messageStop' : _}:
616
- continue
637
+ case {'messageStop' : message_stop }:
638
+ raw_finish_reason = message_stop ['stopReason' ]
639
+ self .provider_details = {'finish_reason' : raw_finish_reason }
640
+ self .finish_reason = _FINISH_REASON_MAP .get (raw_finish_reason )
617
641
case {'metadata' : metadata }:
618
642
if 'usage' in metadata : # pragma: no branch
619
643
self ._usage += self ._map_usage (metadata )
620
- continue
621
644
case {'contentBlockStart' : content_block_start }:
622
645
index = content_block_start ['contentBlockIndex' ]
623
646
start = content_block_start ['start' ]
0 commit comments