Skip to content

Commit f2abc34

Browse files
authored
Merge pull request #314 from lloydhamilton/fix/added_bedrock_streaming
feat/added bedrock streaming
2 parents 13021df + a0e2125 commit f2abc34

File tree

4 files changed

+209
-18
lines changed

4 files changed

+209
-18
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""AWS Bedrock ModelClient integration."""
22

3+
import json
34
import os
4-
from typing import Dict, Optional, Any, Callable
5+
from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
56
import backoff
67
import logging
78

@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
2627
r"""When we only need the content of the first message.
2728
It is the default parser for chat completion."""
2829
return completion["output"]["message"]["content"][0]["text"]
29-
return completion["output"]["message"]["content"][0]["text"]
3030

3131

3232
__all__ = [
@@ -117,6 +117,7 @@ def __init__(
117117
self._aws_connection_timeout = aws_connection_timeout
118118
self._aws_read_timeout = aws_read_timeout
119119

120+
self._client = None
120121
self.session = None
121122
self.sync_client = self.init_sync_client()
122123
self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
158159
def init_async_client(self):
159160
raise NotImplementedError("Async call not implemented yet.")
160161

161-
def parse_chat_completion(self, completion):
162-
log.debug(f"completion: {completion}")
162+
def handle_stream_response(self, stream: dict) -> GeneratorType:
163+
r"""Handle the stream response from bedrock. Yield the chunks.
164+
165+
Args:
166+
stream (dict): The stream response generator from bedrock.
167+
168+
Returns:
169+
GeneratorType: A generator that yields the chunks from bedrock stream.
170+
"""
171+
try:
172+
stream: GeneratorType = stream["stream"]
173+
for chunk in stream:
174+
log.debug(f"Raw chunk: {chunk}")
175+
yield chunk
176+
except Exception as e:
177+
log.debug(f"Error in handle_stream_response: {e}") # Debug print
178+
raise
179+
180+
def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
181+
r"""Parse the completion, and assign it into the raw_response attribute.
182+
183+
If the completion is a stream, it will be handled by the handle_stream_response
184+
method that returns a Generator. Otherwise, the completion will be parsed using
185+
the get_first_message_content method.
186+
187+
Args:
188+
completion (dict): The completion response from bedrock API call.
189+
190+
Returns:
191+
GeneratorOutput: A generator output object with the parsed completion. May
192+
return a generator if the completion is a stream.
193+
"""
163194
try:
164-
data = completion["output"]["message"]["content"][0]["text"]
165-
usage = self.track_completion_usage(completion)
166-
return GeneratorOutput(data=None, usage=usage, raw_response=data)
195+
usage = None
196+
data = self.chat_completion_parser(completion)
197+
if not isinstance(data, GeneratorType):
198+
# Streaming completion usage tracking is not implemented.
199+
usage = self.track_completion_usage(completion)
200+
return GeneratorOutput(
201+
data=None, error=None, raw_response=data, usage=usage
202+
)
167203
except Exception as e:
168-
log.error(f"Error parsing completion: {e}")
204+
log.error(f"Error parsing the completion: {e}")
169205
return GeneratorOutput(
170-
data=None, error=str(e), raw_response=str(completion)
206+
data=None, error=str(e), raw_response=json.dumps(completion)
171207
)
172208

173209
def track_completion_usage(self, completion: Dict) -> CompletionUsage:
@@ -191,6 +227,7 @@ def list_models(self):
191227
print(f" Description: {model['description']}")
192228
print(f" Provider: {model['provider']}")
193229
print("")
230+
194231
except Exception as e:
195232
print(f"Error listing models: {e}")
196233

@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
222259
bedrock_runtime_exceptions.ModelErrorException,
223260
bedrock_runtime_exceptions.ValidationException,
224261
),
225-
max_time=5,
262+
max_time=2,
226263
)
227-
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
264+
def call(
265+
self,
266+
api_kwargs: Dict = {},
267+
model_type: ModelType = ModelType.UNDEFINED,
268+
) -> dict:
228269
"""
229270
kwargs is the combined input and model_kwargs
230271
"""
231272
if model_type == ModelType.LLM:
232-
return self.sync_client.converse(**api_kwargs)
273+
if "stream" in api_kwargs and api_kwargs.get("stream", False):
274+
log.debug("Streaming call")
275+
api_kwargs.pop(
276+
"stream", None
277+
) # stream is not a valid parameter for bedrock
278+
self.chat_completion_parser = self.handle_stream_response
279+
return self.sync_client.converse_stream(**api_kwargs)
280+
else:
281+
api_kwargs.pop("stream", None)
282+
return self.sync_client.converse(**api_kwargs)
233283
else:
234284
raise ValueError(f"model_type {model_type} is not supported")
235285

adalflow/poetry.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

adalflow/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ groq = "^0.9.0"
7575
google-generativeai = "^0.7.2"
7676
anthropic = "^0.31.1"
7777
lancedb = "^0.5.2"
78+
boto3 = "^1.35.19"
79+
7880
# TODO: cant make qdrant work here
7981
# qdrant_client = [
8082
# { version = ">=1.12.2,<2.0.0", optional = true, markers = "python_version >= '3.10'" },
(filename not captured in this page scrape — from its content this is the new Bedrock client test module, likely adalflow/tests/test_bedrock_client.py; verify against the repository)

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import unittest
2+
from unittest.mock import Mock, patch
3+
from adalflow.core.types import ModelType, GeneratorOutput
4+
from adalflow.components.model_client import BedrockAPIClient
5+
6+
7+
class TestBedrockClient(unittest.TestCase):
8+
def setUp(self) -> None:
9+
"""Set up mocks and test data.
10+
11+
Mocks the boto3 session and the init_sync_client method. Mocks will create a
12+
mock bedrock client and mock responses that can be reused across tests.
13+
"""
14+
self.session_patcher = patch(
15+
"adalflow.components.model_client.bedrock_client.boto3.Session"
16+
)
17+
self.mock_session = self.session_patcher.start()
18+
self.mock_boto3_client = Mock()
19+
self.mock_session.return_value.client.return_value = self.mock_boto3_client
20+
self.init_sync_patcher = patch.object(BedrockAPIClient, "init_sync_client")
21+
self.mock_init_sync_client = self.init_sync_patcher.start()
22+
self.mock_sync_client = Mock()
23+
self.mock_init_sync_client.return_value = self.mock_sync_client
24+
self.mock_sync_client.converse = Mock()
25+
self.mock_sync_client.converse_stream = Mock()
26+
self.client = BedrockAPIClient()
27+
self.client.sync_client = self.mock_sync_client
28+
29+
self.mock_response = {
30+
"ResponseMetadata": {
31+
"RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
32+
"HTTPStatusCode": 200,
33+
"HTTPHeaders": {
34+
"date": "Sat, 30 Nov 2024 14:27:44 GMT",
35+
"content-type": "application/json",
36+
"content-length": "273",
37+
"connection": "keep-alive",
38+
"x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
39+
},
40+
"RetryAttempts": 0,
41+
},
42+
"output": {
43+
"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
44+
},
45+
"stopReason": "end_turn",
46+
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
47+
"metrics": {"latencyMs": 430},
48+
}
49+
self.mock_stream_response = {
50+
"ResponseMetadata": {
51+
"RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
52+
"HTTPStatusCode": 200,
53+
"HTTPHeaders": {
54+
"date": "Sun, 12 Jan 2025 15:10:00 GMT",
55+
"content-type": "application/vnd.amazon.eventstream",
56+
"transfer-encoding": "chunked",
57+
"connection": "keep-alive",
58+
"x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
59+
},
60+
"RetryAttempts": 0,
61+
},
62+
"stream": iter(()),
63+
}
64+
self.api_kwargs = {
65+
"messages": [{"role": "user", "content": "Hello"}],
66+
"model": "gpt-3.5-turbo",
67+
}
68+
69+
def tearDown(self) -> None:
70+
"""Stop the patchers."""
71+
self.init_sync_patcher.stop()
72+
73+
def test_call(self) -> None:
74+
"""Tests that the call method calls the converse method correctly."""
75+
self.mock_sync_client.converse = Mock(return_value=self.mock_response)
76+
self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
77+
78+
result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
79+
80+
# Assertions: converse is called once and stream is not called
81+
self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
82+
self.mock_sync_client.converse_stream.assert_not_called()
83+
self.assertEqual(result, self.mock_response)
84+
85+
def test_parse_chat_completion(self) -> None:
86+
"""Tests that the parse_chat_completion method returns expected object."""
87+
output = self.client.parse_chat_completion(completion=self.mock_response)
88+
self.assertTrue(isinstance(output, GeneratorOutput))
89+
self.assertEqual(output.raw_response, "Hello, world!")
90+
self.assertEqual(output.usage.prompt_tokens, 20)
91+
self.assertEqual(output.usage.completion_tokens, 10)
92+
self.assertEqual(output.usage.total_tokens, 30)
93+
94+
def test_parse_chat_completion_call_usage(self) -> None:
95+
"""Test that the parse_chat_completion calls usage completion when not
96+
streaming."""
97+
mock_track_completion_usage = Mock()
98+
self.client.track_completion_usage = mock_track_completion_usage
99+
generator_output = self.client.parse_chat_completion(self.mock_response)
100+
101+
mock_track_completion_usage.assert_called_once()
102+
assert isinstance(generator_output, GeneratorOutput)
103+
104+
def test_streaming_call(self) -> None:
105+
"""Test that a streaming call calls the converse_stream method."""
106+
self.mock_sync_client.converse = Mock(return_value=self.mock_response)
107+
self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
108+
109+
# Call the call method.
110+
stream_kwargs = self.api_kwargs | {"stream": True}
111+
self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
112+
113+
# Assertions: Streaming method is called
114+
self.mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
115+
self.mock_sync_client.converse.assert_not_called()
116+
117+
def test_call_value_error(self) -> None:
118+
"""Test that a ValueError is raised when an invalid model_type is passed."""
119+
with self.assertRaises(ValueError):
120+
self.client.call(
121+
api_kwargs={},
122+
model_type=ModelType.UNDEFINED, # This should trigger ValueError
123+
)
124+
125+
def test_parse_streaming_chat_completion(self) -> None:
126+
"""Test that the parse_chat_completion does not call usage completion when
127+
streaming."""
128+
mock_track_completion_usage = Mock()
129+
self.client.track_completion_usage = mock_track_completion_usage
130+
131+
self.client.chat_completion_parser = self.client.handle_stream_response
132+
generator_output = self.client.parse_chat_completion(self.mock_stream_response)
133+
134+
mock_track_completion_usage.assert_not_called()
135+
assert isinstance(generator_output, GeneratorOutput)
136+
137+
138+
if __name__ == "__main__":
139+
unittest.main()

0 commit comments

Comments (0)