Skip to content

Commit 8f23132

Browse files
committed
test: add aws bedrock unit test
1 parent 7e98938 commit 8f23132

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import unittest
2+
from unittest.mock import patch, Mock
3+
4+
# use the openai for mocking standard data types
5+
from openai.types import CompletionUsage
6+
from openai.types.chat import ChatCompletion
7+
8+
from adalflow.core.types import ModelType, GeneratorOutput
9+
from adalflow.components.model_client import BedrockAPIClient
10+
11+
12+
def getenv_side_effect(key):
13+
# This dictionary can hold more keys and values as needed
14+
env_vars = {
15+
"AWS_ACCESS_KEY_ID": "fake_api_key",
16+
"AWS_SECRET_ACCESS_KEY": "fake_api_key",
17+
"AWS_REGION_NAME": "fake_api_key",
18+
}
19+
return env_vars.get(key, None) # Returns None if key is not found
20+
21+
22+
# modified from test_openai_client.py
23+
class TestBedrockClient(unittest.TestCase):
24+
def setUp(self):
25+
self.client = BedrockAPIClient()
26+
self.mock_response = {
27+
"ResponseMetadata": {
28+
"RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
29+
"HTTPStatusCode": 200,
30+
"HTTPHeaders": {
31+
"date": "Sat, 30 Nov 2024 14:27:44 GMT",
32+
"content-type": "application/json",
33+
"content-length": "273",
34+
"connection": "keep-alive",
35+
"x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
36+
},
37+
"RetryAttempts": 0,
38+
},
39+
"output": {
40+
"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
41+
},
42+
"stopReason": "end_turn",
43+
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
44+
"metrics": {"latencyMs": 430},
45+
}
46+
47+
self.api_kwargs = {
48+
"messages": [{"role": "user", "content": "Hello"}],
49+
"model": "gpt-3.5-turbo",
50+
}
51+
52+
@patch.object(BedrockAPIClient, "init_sync_client")
53+
@patch("adalflow.components.model_client.bedrock_client.boto3")
54+
def test_call(self, MockBedrock, mock_init_sync_client):
55+
mock_sync_client = Mock()
56+
MockBedrock.return_value = mock_sync_client
57+
mock_init_sync_client.return_value = mock_sync_client
58+
59+
# Mock the client's api: converse
60+
mock_sync_client.converse = Mock(return_value=self.mock_response)
61+
62+
# Set the sync client
63+
self.client.sync_client = mock_sync_client
64+
65+
# Call the call method
66+
result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
67+
68+
# Assertions
69+
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
70+
self.assertEqual(result, self.mock_response)
71+
72+
# test parse_chat_completion
73+
output = self.client.parse_chat_completion(completion=self.mock_response)
74+
self.assertTrue(isinstance(output, GeneratorOutput))
75+
self.assertEqual(output.raw_response, "Hello, world!")
76+
self.assertEqual(output.usage.prompt_tokens, 20)
77+
self.assertEqual(output.usage.completion_tokens, 10)
78+
self.assertEqual(output.usage.total_tokens, 30)
79+
80+
81+
if __name__ == "__main__":
82+
unittest.main()

0 commit comments

Comments
 (0)