1818
1919from __future__ import annotations
2020
21+ import json
22+
2123from botocore .eventstream import EventStream
2224from wrapt import ObjectProxy
2325
@@ -46,20 +48,21 @@ def __iter__(self):
4648 def _process_event (self , event ):
4749 if "messageStart" in event :
4850 # {'messageStart': {'role': 'assistant'}}
49- pass
51+ return
5052
5153 if "contentBlockDelta" in event :
5254 # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
53- pass
55+ return
5456
5557 if "contentBlockStop" in event :
5658 # {'contentBlockStop': {'contentBlockIndex': 0}}
57- pass
59+ return
5860
5961 if "messageStop" in event :
6062 # {'messageStop': {'stopReason': 'end_turn'}}
6163 if stop_reason := event ["messageStop" ].get ("stopReason" ):
6264 self ._response ["stopReason" ] = stop_reason
65+ return
6366
6467 if "metadata" in event :
6568 # {'metadata': {'usage': {'inputTokens': 12, 'outputTokens': 15, 'totalTokens': 27}, 'metrics': {'latencyMs': 2980}}}
@@ -72,3 +75,136 @@ def _process_event(self, event):
7275 self ._response ["usage" ]["outputTokens" ] = output_tokens
7376
7477 self ._stream_done_callback (self ._response )
78+ return
79+
80+
# pylint: disable=abstract-method
class InvokeModelWithResponseStreamWrapper(ObjectProxy):
    """Wrapper for botocore.eventstream.EventStream

    Iterating the wrapper yields every event of the wrapped stream unchanged.
    Along the way, each ``chunk`` event payload is decoded and inspected so
    that the stop reason and token usage can be accumulated into a dict shaped
    like a Converse API response. ``stream_done_callback`` is invoked with
    that dict once the model-specific final chunk has been seen.

    Only model ids containing ``amazon.titan``, ``amazon.nova`` or
    ``anthropic.claude`` are inspected; events for other models are passed
    through without any bookkeeping.
    """

    def __init__(
        self,
        stream: EventStream,
        stream_done_callback,
        model_id: str,
    ):
        """Wrap ``stream``.

        Args:
            stream: the event stream to proxy.
            stream_done_callback: callable invoked with the accumulated
                response dict when the final chunk has been processed.
            model_id: model identifier, used to pick the chunk parser.
        """
        super().__init__(stream)

        self._stream_done_callback = stream_done_callback
        self._model_id = model_id

        # accumulating things in the same shape of the Converse API
        # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish"}
        self._response = {}

    def __iter__(self):
        """Yield each event of the wrapped stream after inspecting it."""
        for event in self.__wrapped__:
            self._process_event(event)
            yield event

    def _process_event(self, event) -> None:
        """Decode the JSON payload of a ``chunk`` event and dispatch it.

        Events without a ``chunk`` key, and chunks whose payload cannot be
        turned into a JSON object, are ignored: the bookkeeping done here
        must never interrupt the caller's iteration over the stream.
        """
        if "chunk" not in event:
            return

        json_bytes = event["chunk"].get("bytes", b"")
        try:
            chunk = json.loads(json_bytes.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Undecodable bytes are skipped exactly like malformed JSON.
            return

        # The model-specific parsers below use dict lookups; any other JSON
        # value (list, string, number, null) carries nothing they can read.
        if not isinstance(chunk, dict):
            return

        if "amazon.titan" in self._model_id:
            self._process_amazon_titan_chunk(chunk)
        elif "amazon.nova" in self._model_id:
            self._process_amazon_nova_chunk(chunk)
        elif "anthropic.claude" in self._model_id:
            self._process_anthropic_claude_chunk(chunk)

    def _process_invocation_metrics(self, invocation_metrics) -> None:
        """Reset ``usage`` and copy token counts from the invocation metrics.

        A count is only recorded when it is present and truthy, so a missing
        or zero count leaves the corresponding key out of ``usage``.
        """
        self._response["usage"] = {}
        if input_tokens := invocation_metrics.get("inputTokenCount"):
            self._response["usage"]["inputTokens"] = input_tokens

        if output_tokens := invocation_metrics.get("outputTokenCount"):
            self._response["usage"]["outputTokens"] = output_tokens

    def _process_amazon_titan_chunk(self, chunk) -> None:
        """Record stop reason and usage from an Amazon Titan chunk.

        The chunk carrying ``amazon-bedrock-invocationMetrics`` is treated as
        the end of the stream and triggers the done callback.
        """
        if (stop_reason := chunk.get("completionReason")) is not None:
            self._response["stopReason"] = stop_reason

        if invocation_metrics := chunk.get("amazon-bedrock-invocationMetrics"):
            # "amazon-bedrock-invocationMetrics":{
            #   "inputTokenCount":9,"outputTokenCount":128,"invocationLatency":3569,"firstByteLatency":2180
            # }
            self._process_invocation_metrics(invocation_metrics)
            self._stream_done_callback(self._response)

    def _process_amazon_nova_chunk(self, chunk) -> None:
        """Record stop reason and usage from an Amazon Nova chunk.

        The ``metadata`` chunk is treated as the end of the stream and
        triggers the done callback.
        """
        # Chunks that carry no stop reason or usage are skipped up front:
        # {'messageStart': {'role': 'assistant'}}
        # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
        # {'contentBlockStop': {'contentBlockIndex': 0}}
        ignored_keys = ("messageStart", "contentBlockDelta", "contentBlockStop")
        if any(key in chunk for key in ignored_keys):
            return

        if "messageStop" in chunk:
            # {'messageStop': {'stopReason': 'end_turn'}}
            if stop_reason := chunk["messageStop"].get("stopReason"):
                self._response["stopReason"] = stop_reason
            return

        if "metadata" in chunk:
            # {'metadata': {'usage': {'inputTokens': 8, 'outputTokens': 117}, 'metrics': {}, 'trace': {}}}
            if usage := chunk["metadata"].get("usage"):
                self._response["usage"] = {}
                if input_tokens := usage.get("inputTokens"):
                    self._response["usage"]["inputTokens"] = input_tokens

                if output_tokens := usage.get("outputTokens"):
                    self._response["usage"]["outputTokens"] = output_tokens

            self._stream_done_callback(self._response)

    def _process_anthropic_claude_chunk(self, chunk) -> None:
        """Record stop reason and usage from an Anthropic Claude chunk.

        Only ``message_delta`` (stop reason) and ``message_stop`` (invocation
        metrics, end of stream) are acted upon. Every other message type is
        ignored, for example:

        {'type': 'message_start', 'message': {'id': 'id', 'type': 'message', 'role': 'assistant', 'model': 'claude-2.0', 'content': [], 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 18, 'output_tokens': 1}}}
        {'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}
        {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Here'}}
        {'type': 'content_block_stop', 'index': 0}
        """
        message_type = chunk.get("type")

        if message_type == "message_delta":
            # {'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 123}}
            stop_reason = chunk.get("delta", {}).get("stop_reason")
            if stop_reason is not None:
                self._response["stopReason"] = stop_reason
        elif message_type == "message_stop":
            # {'type': 'message_stop', 'amazon-bedrock-invocationMetrics': {'inputTokenCount': 18, 'outputTokenCount': 123, 'invocationLatency': 5250, 'firstByteLatency': 290}}
            if invocation_metrics := chunk.get(
                "amazon-bedrock-invocationMetrics"
            ):
                self._process_invocation_metrics(invocation_metrics)
                self._stream_done_callback(self._response)
0 commit comments