Skip to content

Commit 3329ccc

Browse files
author
Lloyd Hamilton
committed
Added streaming call for bedrock API
1 parent 3e46f8e commit 3329ccc

File tree

1 file changed

+55
-17
lines changed

1 file changed

+55
-17
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
"""AWS Bedrock ModelClient integration."""
2-
2+
import json
33
import os
4-
from typing import Dict, Optional, Any, Callable
4+
from typing import (
5+
Dict,
6+
Optional,
7+
Any,
8+
Callable,
9+
Generator as GeneratorType
10+
)
511
import backoff
612
import logging
713

@@ -15,6 +21,7 @@
1521
from botocore.config import Config
1622

1723
log = logging.getLogger(__name__)
24+
log.level = logging.DEBUG
1825

1926
bedrock_runtime_exceptions = boto3.client(
2027
service_name="bedrock-runtime",
@@ -26,7 +33,6 @@ def get_first_message_content(completion: Dict) -> str:
2633
r"""When we only need the content of the first message.
2734
It is the default parser for chat completion."""
2835
return completion["output"]["message"]["content"][0]["text"]
29-
return completion["output"]["message"]["content"][0]["text"]
3036

3137

3238
__all__ = [
@@ -117,6 +123,7 @@ def __init__(
117123
self._aws_connection_timeout = aws_connection_timeout
118124
self._aws_read_timeout = aws_read_timeout
119125

126+
self._client = None
120127
self.session = None
121128
self.sync_client = self.init_sync_client()
122129
self.chat_completion_parser = (
@@ -158,16 +165,34 @@ def init_sync_client(self):
158165
def init_async_client(self):
159166
raise NotImplementedError("Async call not implemented yet.")
160167

161-
def parse_chat_completion(self, completion):
162-
log.debug(f"completion: {completion}")
168+
@staticmethod
169+
def parse_stream_response(completion: dict) -> str:
170+
if "contentBlockDelta" in completion:
171+
if delta_chunk := completion["contentBlockDelta"]["delta"]:
172+
return delta_chunk["text"]
173+
return ''
174+
175+
def handle_stream_response(self, stream: dict) -> GeneratorType:
176+
try:
177+
for chunk in stream["stream"]:
178+
log.debug(f"Raw chunk: {chunk}")
179+
parsed_content = self.parse_stream_response(chunk)
180+
yield parsed_content
181+
except Exception as e:
182+
print(f"Error in handle_stream_response: {e}") # Debug print
183+
raise
184+
185+
def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
186+
"""Parse the completion, and put it into the raw_response."""
163187
try:
164-
data = completion["output"]["message"]["content"][0]["text"]
165-
usage = self.track_completion_usage(completion)
166-
return GeneratorOutput(data=None, usage=usage, raw_response=data)
188+
data = self.handle_stream_response(completion)
189+
return GeneratorOutput(
190+
data=None, error=None, raw_response=data
191+
)
167192
except Exception as e:
168-
log.error(f"Error parsing completion: {e}")
193+
log.error(f"Error parsing the completion: {e}")
169194
return GeneratorOutput(
170-
data=None, error=str(e), raw_response=str(completion)
195+
data=None, error=str(e), raw_response=json.dumps(completion)
171196
)
172197

173198
def track_completion_usage(self, completion: Dict) -> CompletionUsage:
@@ -184,12 +209,13 @@ def list_models(self):
184209

185210
try:
186211
response = self._client.list_foundation_models()
187-
models = response.get("models", [])
212+
models = response.get("modelSummaries", [])
188213
for model in models:
189214
print(f"Model ID: {model['modelId']}")
190-
print(f" Name: {model['name']}")
191-
print(f" Description: {model['description']}")
192-
print(f" Provider: {model['provider']}")
215+
print(f" Name: {model['modelName']}")
216+
print(f" Input Modalities: {model['inputModalities']}")
217+
print(f" Output Modalities: {model['outputModalities']}")
218+
print(f" Provider: {model['providerName']}")
193219
print("")
194220
except Exception as e:
195221
print(f"Error listing models: {e}")
@@ -222,14 +248,26 @@ def convert_inputs_to_api_kwargs(
222248
bedrock_runtime_exceptions.ModelErrorException,
223249
bedrock_runtime_exceptions.ValidationException,
224250
),
225-
max_time=5,
251+
max_time=2,
226252
)
227-
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
253+
def call(
254+
self,
255+
api_kwargs: Dict = {},
256+
model_type: ModelType = ModelType.UNDEFINED,
257+
stream: bool = False
258+
) -> dict:
228259
"""
229260
kwargs is the combined input and model_kwargs
230261
"""
231262
if model_type == ModelType.LLM:
232-
return self.sync_client.converse(**api_kwargs)
263+
if "stream" in api_kwargs and api_kwargs.get("stream", False):
264+
log.debug("Streaming call")
265+
api_kwargs.pop("stream") # stream is not a valid parameter for bedrock
266+
self.chat_completion_parser = self.handle_stream_response
267+
return self.sync_client.converse_stream(**api_kwargs)
268+
else:
269+
api_kwargs.pop("stream")
270+
return self.sync_client.converse(**api_kwargs)
233271
else:
234272
raise ValueError(f"model_type {model_type} is not supported")
235273

0 commit comments

Comments
 (0)