Skip to content

Commit 61ef29d

Browse files
author
Lloyd Hamilton
committed
test: add tests for bedrock client
1 parent f9d46e3 commit 61ef29d

File tree

4 files changed

+118
-16
lines changed

4 files changed

+118
-16
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
"""AWS Bedrock ModelClient integration."""
2+
23
import json
34
import os
4-
from typing import (
5-
Dict,
6-
Optional,
7-
Any,
8-
Callable,
9-
Generator as GeneratorType
10-
)
5+
from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
116
import backoff
127
import logging
138

149
from adalflow.core.model_client import ModelClient
1510
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
16-
from adalflow.utils import printc
1711

1812
from adalflow.utils.lazy_import import safe_import, OptionalPackages
1913

@@ -166,21 +160,46 @@ def init_async_client(self):
166160
raise NotImplementedError("Async call not implemented yet.")
167161

168162
def handle_stream_response(self, stream: dict) -> GeneratorType:
163+
r"""Handle the stream response from bedrock. Yield the chunks.
164+
165+
Args:
166+
stream (dict): The stream response generator from bedrock.
167+
168+
Returns:
169+
GeneratorType: A generator that yields the chunks from bedrock stream.
170+
"""
169171
try:
170172
stream: GeneratorType = stream["stream"]
171173
for chunk in stream:
172174
log.debug(f"Raw chunk: {chunk}")
173175
yield chunk
174176
except Exception as e:
175177
print(f"Error in handle_stream_response: {e}") # Debug print
176-
raise from e
178+
raise
177179

178180
def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
179-
"""Parse the completion, and put it into the raw_response."""
181+
r"""Parse the completion, and assign it into the raw_response attribute.
182+
183+
If the completion is a stream, it will be handled by the handle_stream_response
184+
method that returns a Generator. Otherwise, the completion will be parsed using
185+
the get_first_message_content method.
186+
187+
Args:
188+
completion (dict): The completion response from bedrock API call.
189+
190+
Returns:
191+
GeneratorOutput: A generator output object with the parsed completion. May
192+
return a generator if the completion is a stream.
193+
"""
180194
try:
195+
usage = None
196+
print(completion)
181197
data = self.chat_completion_parser(completion)
198+
if not isinstance(data, GeneratorType):
199+
# Streaming completion usage tracking is not implemented.
200+
usage = self.track_completion_usage(completion)
182201
return GeneratorOutput(
183-
data=None, error=None, raw_response=data
202+
data=None, error=None, raw_response=data, usage=usage
184203
)
185204
except Exception as e:
186205
log.error(f"Error parsing the completion: {e}")
@@ -254,7 +273,9 @@ def call(
254273
if model_type == ModelType.LLM:
255274
if "stream" in api_kwargs and api_kwargs.get("stream", False):
256275
log.debug("Streaming call")
257-
api_kwargs.pop("stream") # stream is not a valid parameter for bedrock
276+
api_kwargs.pop(
277+
"stream", None
278+
) # stream is not a valid parameter for bedrock
258279
self.chat_completion_parser = self.handle_stream_response
259280
return self.sync_client.converse_stream(**api_kwargs)
260281
else:

adalflow/poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

adalflow/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ groq = "^0.9.0"
8080
google-generativeai = "^0.7.2"
8181
anthropic = "^0.31.1"
8282
lancedb = "^0.5.2"
83+
boto3 = "^1.35.19"
8384

8485

8586

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
from unittest.mock import patch, Mock
3+
4+
# use the openai for mocking standard data types
5+
6+
from adalflow.core.types import ModelType, GeneratorOutput
7+
from adalflow.components.model_client import BedrockAPIClient
8+
9+
10+
def getenv_side_effect(key):
    """Stand-in for ``os.getenv`` that serves fake AWS credentials.

    Returns the fake value for a known AWS key, or None for any other key.
    """
    fake_credentials = {
        "AWS_ACCESS_KEY_ID": "fake_api_key",
        "AWS_SECRET_ACCESS_KEY": "fake_api_key",
        "AWS_REGION_NAME": "fake_api_key",
    }
    # dict.get already falls back to None for unknown keys.
    return fake_credentials.get(key)
18+
19+
20+
# modified from test_openai_client.py
class TestBedrockClient(unittest.TestCase):
    """Unit tests for BedrockAPIClient.call and parse_chat_completion."""

    def setUp(self):
        self.client = BedrockAPIClient()
        # Canned converse() response mirroring the bedrock API shape.
        self.mock_response = {
            "ResponseMetadata": {
                "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
                "HTTPStatusCode": 200,
                "HTTPHeaders": {
                    "date": "Sat, 30 Nov 2024 14:27:44 GMT",
                    "content-type": "application/json",
                    "content-length": "273",
                    "connection": "keep-alive",
                    "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
                },
                "RetryAttempts": 0,
            },
            "output": {
                "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
            },
            "stopReason": "end_turn",
            "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
            "metrics": {"latencyMs": 430},
        }
        self.api_kwargs = {
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "gpt-3.5-turbo",
        }

    @patch.object(BedrockAPIClient, "init_sync_client")
    @patch("adalflow.components.model_client.bedrock_client.boto3")
    def test_call(self, MockBedrock, mock_init_sync_client):
        # Build a fake bedrock runtime client whose converse() returns the
        # canned response, and wire it in as the sync client.
        fake_client = Mock()
        fake_client.converse = Mock(return_value=self.mock_response)
        MockBedrock.return_value = fake_client
        mock_init_sync_client.return_value = fake_client
        self.client.sync_client = fake_client

        # Non-streaming LLM call should pass api_kwargs straight through.
        result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
        fake_client.converse.assert_called_once_with(**self.api_kwargs)
        self.assertEqual(result, self.mock_response)

        # Parsing the completion should surface the text and token usage.
        output = self.client.parse_chat_completion(completion=self.mock_response)
        self.assertIsInstance(output, GeneratorOutput)
        self.assertEqual(output.raw_response, "Hello, world!")
        self.assertEqual(output.usage.prompt_tokens, 20)
        self.assertEqual(output.usage.completion_tokens, 10)
        self.assertEqual(output.usage.total_tokens, 30)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)