11"""AWS Bedrock ModelClient integration."""
2-
2+ import json
33import os
4- from typing import Dict , Optional , Any , Callable
4+ from typing import (
5+ Dict ,
6+ Optional ,
7+ Any ,
8+ Callable ,
9+ Generator as GeneratorType
10+ )
511import backoff
612import logging
713
1521from botocore .config import Config
1622
1723log = logging .getLogger (__name__ )
24+ log .level = logging .DEBUG
1825
1926bedrock_runtime_exceptions = boto3 .client (
2027 service_name = "bedrock-runtime" ,
@@ -26,7 +33,6 @@ def get_first_message_content(completion: Dict) -> str:
2633 r"""When we only need the content of the first message.
2734 It is the default parser for chat completion."""
2835 return completion ["output" ]["message" ]["content" ][0 ]["text" ]
29- return completion ["output" ]["message" ]["content" ][0 ]["text" ]
3036
3137
3238__all__ = [
@@ -117,6 +123,7 @@ def __init__(
         self._aws_connection_timeout = aws_connection_timeout
         self._aws_read_timeout = aws_read_timeout

+        self._client = None
         self.session = None
         self.sync_client = self.init_sync_client()
         self.chat_completion_parser = (
@@ -158,16 +165,34 @@ def init_sync_client(self):
     def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")

-    def parse_chat_completion(self, completion):
-        log.debug(f"completion: {completion}")
+    @staticmethod
+    def parse_stream_response(completion: dict) -> str:
+        if "contentBlockDelta" in completion:
+            if delta_chunk := completion["contentBlockDelta"]["delta"]:
+                return delta_chunk["text"]
+        return ''
+
+    def handle_stream_response(self, stream: dict) -> GeneratorType:
+        try:
+            for chunk in stream["stream"]:
+                log.debug(f"Raw chunk: {chunk}")
+                parsed_content = self.parse_stream_response(chunk)
+                yield parsed_content
+        except Exception as e:
+            print(f"Error in handle_stream_response: {e}")  # Debug print
+            raise
+
+    def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
+        """Parse the completion, and put it into the raw_response."""
         try:
-            data = completion["output"]["message"]["content"][0]["text"]
-            usage = self.track_completion_usage(completion)
-            return GeneratorOutput(data=None, usage=usage, raw_response=data)
+            data = self.handle_stream_response(completion)
+            return GeneratorOutput(
+                data=None, error=None, raw_response=data
+            )
         except Exception as e:
-            log.error(f"Error parsing completion: {e}")
+            log.error(f"Error parsing the completion: {e}")
             return GeneratorOutput(
-                data=None, error=str(e), raw_response=str(completion)
+                data=None, error=str(e), raw_response=json.dumps(completion)
             )

     def track_completion_usage(self, completion: Dict) -> CompletionUsage:
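
For reference, here is a minimal sketch of the event shapes the new streaming helpers consume. Bedrock's converse_stream response carries an event stream under the "stream" key; events such as messageStart and messageStop carry no text, while contentBlockDelta events hold the incremental text that parse_stream_response extracts. The event values below are illustrative, and the class name BedrockAPIClient is an assumption, since the enclosing class is not shown in this hunk.

# Illustrative ConverseStream events; field names follow the Bedrock Converse
# stream API, text values are made up for this sketch.
events = [
    {"messageStart": {"role": "assistant"}},
    {"contentBlockDelta": {"delta": {"text": "Hello"}, "contentBlockIndex": 0}},
    {"contentBlockDelta": {"delta": {"text": " world"}, "contentBlockIndex": 0}},
    {"messageStop": {"stopReason": "end_turn"}},
]

# parse_stream_response is a @staticmethod, so it can be exercised without a
# client instance; non-delta events fall through to the empty-string branch.
chunks = [BedrockAPIClient.parse_stream_response(event) for event in events]
assert "".join(chunks) == "Hello world"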
@@ -184,12 +209,13 @@ def list_models(self):

         try:
             response = self._client.list_foundation_models()
-            models = response.get("models", [])
+            models = response.get("modelSummaries", [])
             for model in models:
                 print(f"Model ID: {model['modelId']}")
-                print(f"  Name: {model['name']}")
-                print(f"  Description: {model['description']}")
-                print(f"  Provider: {model['provider']}")
+                print(f"  Name: {model['modelName']}")
+                print(f"  Input Modalities: {model['inputModalities']}")
+                print(f"  Output Modalities: {model['outputModalities']}")
+                print(f"  Provider: {model['providerName']}")
                 print("")
         except Exception as e:
             print(f"Error listing models: {e}")
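
The key renames above match the ListFoundationModels response: entries live under "modelSummaries" and expose modelName and providerName rather than name, description, and provider. An abridged example entry (field names are the ones the updated loop reads; values are illustrative only):

# One entry from list_foundation_models()["modelSummaries"], trimmed to the
# fields printed by list_models().
example_summary = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
    "modelName": "Claude 3 Sonnet",
    "providerName": "Anthropic",
    "inputModalities": ["TEXT", "IMAGE"],
    "outputModalities": ["TEXT"],
}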
@@ -222,14 +248,26 @@ def convert_inputs_to_api_kwargs(
             bedrock_runtime_exceptions.ModelErrorException,
             bedrock_runtime_exceptions.ValidationException,
         ),
-        max_time=5,
+        max_time=2,
     )
-    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+    def call(
+        self,
+        api_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+        stream: bool = False
+    ) -> dict:
         """
         kwargs is the combined input and model_kwargs
         """
         if model_type == ModelType.LLM:
-            return self.sync_client.converse(**api_kwargs)
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("Streaming call")
+                api_kwargs.pop("stream")  # stream is not a valid parameter for bedrock
+                self.chat_completion_parser = self.handle_stream_response
+                return self.sync_client.converse_stream(**api_kwargs)
+            else:
+                api_kwargs.pop("stream", None)  # default avoids a KeyError when "stream" was never set
+                return self.sync_client.converse(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")

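
Finally, a hedged usage sketch of the new streaming path. The class name BedrockAPIClient, its no-argument constructor, and the hand-built api_kwargs are assumptions (neither the class definition nor the body of convert_inputs_to_api_kwargs appears in this diff); the sketch relies only on what the hunks above show: call() pops "stream", routes to converse_stream(), and parse_chat_completion() puts the resulting generator into GeneratorOutput.raw_response.

# Hypothetical end-to-end use of the streaming branch; imports omitted,
# messages follow the Bedrock Converse request format.
client = BedrockAPIClient()

api_kwargs = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
    "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
    "stream": True,  # popped by call(); it only selects converse_stream()
}

completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
output = client.parse_chat_completion(completion)

# raw_response holds the generator from handle_stream_response; iterate it to
# consume text chunks as they arrive.
for text in output.raw_response:
    print(text, end="", flush=True)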