11"""AWS Bedrock ModelClient integration."""
22
3+ import json
34import os
4- from typing import Dict , Optional , Any , Callable
5+ from typing import Dict , Optional , Any , Callable , Generator as GeneratorType
56import backoff
67import logging
78
def get_first_message_content(completion: Dict) -> str:
    r"""When we only need the content of the first message.
    It is the default parser for chat completion.

    Args:
        completion (Dict): A bedrock converse response; the reply text is
            nested under output -> message -> content[0] -> text.

    Returns:
        str: The text of the first content entry of the first message.
    """
    # Single return; the duplicated unreachable return statement is removed.
    return completion["output"]["message"]["content"][0]["text"]
3131
3232__all__ = [
@@ -117,6 +117,7 @@ def __init__(
117117 self ._aws_connection_timeout = aws_connection_timeout
118118 self ._aws_read_timeout = aws_read_timeout
119119
120+ self ._client = None
120121 self .session = None
121122 self .sync_client = self .init_sync_client ()
122123 self .chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
158159 def init_async_client (self ):
159160 raise NotImplementedError ("Async call not implemented yet." )
160161
161- def parse_chat_completion (self , completion ):
162- log .debug (f"completion: { completion } " )
162+ def handle_stream_response (self , stream : dict ) -> GeneratorType :
163+ r"""Handle the stream response from bedrock. Yield the chunks.
164+
165+ Args:
166+ stream (dict): The stream response generator from bedrock.
167+
168+ Returns:
169+ GeneratorType: A generator that yields the chunks from bedrock stream.
170+ """
171+ try :
172+ stream : GeneratorType = stream ["stream" ]
173+ for chunk in stream :
174+ log .debug (f"Raw chunk: { chunk } " )
175+ yield chunk
176+ except Exception as e :
177+ log .debug (f"Error in handle_stream_response: { e } " ) # Debug print
178+ raise
179+
180+ def parse_chat_completion (self , completion : dict ) -> "GeneratorOutput" :
181+ r"""Parse the completion, and assign it into the raw_response attribute.
182+
183+ If the completion is a stream, it will be handled by the handle_stream_response
184+ method that returns a Generator. Otherwise, the completion will be parsed using
185+ the get_first_message_content method.
186+
187+ Args:
188+ completion (dict): The completion response from bedrock API call.
189+
190+ Returns:
191+ GeneratorOutput: A generator output object with the parsed completion. May
192+ return a generator if the completion is a stream.
193+ """
163194 try :
164- data = completion ["output" ]["message" ]["content" ][0 ]["text" ]
165- usage = self .track_completion_usage (completion )
166- return GeneratorOutput (data = None , usage = usage , raw_response = data )
195+ usage = None
196+ data = self .chat_completion_parser (completion )
197+ if not isinstance (data , GeneratorType ):
198+ # Streaming completion usage tracking is not implemented.
199+ usage = self .track_completion_usage (completion )
200+ return GeneratorOutput (
201+ data = None , error = None , raw_response = data , usage = usage
202+ )
167203 except Exception as e :
168- log .error (f"Error parsing completion: { e } " )
204+ log .error (f"Error parsing the completion: { e } " )
169205 return GeneratorOutput (
170- data = None , error = str (e ), raw_response = str (completion )
206+ data = None , error = str (e ), raw_response = json . dumps (completion )
171207 )
172208
173209 def track_completion_usage (self , completion : Dict ) -> CompletionUsage :
@@ -191,6 +227,7 @@ def list_models(self):
191227 print (f" Description: { model ['description' ]} " )
192228 print (f" Provider: { model ['provider' ]} " )
193229 print ("" )
230+
194231 except Exception as e :
195232 print (f"Error listing models: { e } " )
196233
@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
222259 bedrock_runtime_exceptions .ModelErrorException ,
223260 bedrock_runtime_exceptions .ValidationException ,
224261 ),
225- max_time = 5 ,
262+ max_time = 2 ,
226263 )
227- def call (self , api_kwargs : Dict = {}, model_type : ModelType = ModelType .UNDEFINED ):
264+ def call (
265+ self ,
266+ api_kwargs : Dict = {},
267+ model_type : ModelType = ModelType .UNDEFINED ,
268+ ) -> dict :
228269 """
229270 kwargs is the combined input and model_kwargs
230271 """
231272 if model_type == ModelType .LLM :
232- return self .sync_client .converse (** api_kwargs )
273+ if "stream" in api_kwargs and api_kwargs .get ("stream" , False ):
274+ log .debug ("Streaming call" )
275+ api_kwargs .pop (
276+ "stream" , None
277+ ) # stream is not a valid parameter for bedrock
278+ self .chat_completion_parser = self .handle_stream_response
279+ return self .sync_client .converse_stream (** api_kwargs )
280+ else :
281+ api_kwargs .pop ("stream" , None )
282+ return self .sync_client .converse (** api_kwargs )
233283 else :
234284 raise ValueError (f"model_type { model_type } is not supported" )
235285
0 commit comments