Skip to content

Commit 0e60f42

Browse files
author
Lloyd Hamilton
committed
fix: configured mocks for AWS boto3 calls
1 parent 49ac890 commit 0e60f42

File tree

1 file changed

+180
-43
lines changed

1 file changed

+180
-43
lines changed
Lines changed: 180 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from unittest.mock import Mock
2+
from unittest.mock import Mock, patch
33

44
# use the openai for mocking standard data types
55

@@ -9,7 +9,26 @@
99

1010
class TestBedrockClient(unittest.TestCase):
1111
def setUp(self) -> None:
12+
"""Set up mocks and test data.
13+
14+
Mocks the boto3 session and the init_sync_client method. Mocks will create a
15+
mock bedrock client and mock responses that can be reused across tests.
16+
"""
17+
self.session_patcher = patch(
18+
"adalflow.components.model_client.bedrock_client.boto3.Session"
19+
)
20+
self.mock_session = self.session_patcher.start()
21+
self.mock_boto3_client = Mock()
22+
self.mock_session.return_value.client.return_value = self.mock_boto3_client
23+
self.init_sync_patcher = patch.object(BedrockAPIClient, "init_sync_client")
24+
self.mock_init_sync_client = self.init_sync_patcher.start()
25+
self.mock_sync_client = Mock()
26+
self.mock_init_sync_client.return_value = self.mock_sync_client
27+
self.mock_sync_client.converse = Mock()
28+
self.mock_sync_client.converse_stream = Mock()
1229
self.client = BedrockAPIClient()
30+
self.client.sync_client = self.mock_sync_client
31+
1332
self.mock_response = {
1433
"ResponseMetadata": {
1534
"RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
@@ -45,73 +64,68 @@ def setUp(self) -> None:
4564
},
4665
"stream": iter(()),
4766
}
48-
4967
self.api_kwargs = {
5068
"messages": [{"role": "user", "content": "Hello"}],
5169
"model": "gpt-3.5-turbo",
5270
}
5371

54-
def test_call(self) -> None:
55-
"""Test that the converse method is called correctly."""
56-
mock_sync_client = Mock()
57-
58-
# Mock the converse API calls.
59-
mock_sync_client.converse = Mock(return_value=self.mock_response)
60-
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
72+
def tearDown(self) -> None:
73+
"""Stop the patchers."""
74+
self.init_sync_patcher.stop()
6175

62-
# Set the sync client
63-
self.client.sync_client = mock_sync_client
76+
def test_call(self) -> None:
77+
"""Tests that the call method calls the converse method correctly."""
78+
self.mock_sync_client.converse = Mock(return_value=self.mock_response)
79+
self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
6480

65-
# Call the call method
6681
result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
6782

68-
# Assertions
69-
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
70-
mock_sync_client.converse_stream.assert_not_called()
83+
# Assertions: converse is called once and stream is not called
84+
self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
85+
self.mock_sync_client.converse_stream.assert_not_called()
7186
self.assertEqual(result, self.mock_response)
7287

73-
# test parse_chat_completion
88+
def test_parse_chat_completion(self) -> None:
89+
"""Tests that the parse_chat_completion method returns expected object."""
7490
output = self.client.parse_chat_completion(completion=self.mock_response)
7591
self.assertTrue(isinstance(output, GeneratorOutput))
7692
self.assertEqual(output.raw_response, "Hello, world!")
7793
self.assertEqual(output.usage.prompt_tokens, 20)
7894
self.assertEqual(output.usage.completion_tokens, 10)
7995
self.assertEqual(output.usage.total_tokens, 30)
8096

81-
def test_streaming_call(self) -> None:
82-
"""Test that a streaming call calls the converse_stream method."""
83-
mock_sync_client = Mock()
97+
def test_parse_chat_completion_call_usage(self) -> None:
98+
"""Test that the parse_chat_completion calls usage completion when not
99+
streaming."""
100+
mock_track_completion_usage = Mock()
101+
self.client.track_completion_usage = mock_track_completion_usage
102+
generator_output = self.client.parse_chat_completion(self.mock_response)
84103

85-
# Mock the converse API calls.
86-
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
87-
mock_sync_client.converse = Mock(return_value=self.mock_response)
104+
mock_track_completion_usage.assert_called_once()
105+
assert isinstance(generator_output, GeneratorOutput)
88106

89-
# Set the sync client.
90-
self.client.sync_client = mock_sync_client
107+
def test_streaming_call(self) -> None:
108+
"""Test that a streaming call calls the converse_stream method."""
109+
self.mock_sync_client.converse = Mock(return_value=self.mock_response)
110+
self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
91111

92112
# Call the call method.
93113
stream_kwargs = self.api_kwargs | {"stream": True}
94114
self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
95115

96-
# Assert the streaming call was made.
97-
mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
98-
mock_sync_client.converse.assert_not_called()
116+
# Assertions: Streaming method is called
117+
self.mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
118+
self.mock_sync_client.converse.assert_not_called()
99119

100120
def test_call_value_error(self) -> None:
101121
"""Test that a ValueError is raised when an invalid model_type is passed."""
102-
mock_sync_client = Mock()
103-
104-
# Set the sync client
105-
self.client.sync_client = mock_sync_client
106-
107-
# Test that ValueError is raised
108122
with self.assertRaises(ValueError):
109123
self.client.call(
110124
api_kwargs={},
111125
model_type=ModelType.UNDEFINED, # This should trigger ValueError
112126
)
113127

114-
def test_parse_chat_completion(self) -> None:
128+
def test_parse_streaming_chat_completion(self) -> None:
115129
"""Test that the parse_chat_completion does not call usage completion when
116130
streaming."""
117131
mock_track_completion_usage = Mock()
@@ -123,15 +137,138 @@ def test_parse_chat_completion(self) -> None:
123137
mock_track_completion_usage.assert_not_called()
124138
assert isinstance(generator_output, GeneratorOutput)
125139

126-
def test_parse_chat_completion_call_usage(self) -> None:
127-
"""Test that the parse_chat_completion calls usage completion when streaming."""
128-
mock_track_completion_usage = Mock()
129-
self.client.track_completion_usage = mock_track_completion_usage
130-
generator_output = self.client.parse_chat_completion(self.mock_response)
131-
132-
mock_track_completion_usage.assert_called_once()
133-
assert isinstance(generator_output, GeneratorOutput)
134-
135140

136141
if __name__ == "__main__":
137142
unittest.main()
143+
144+
#
145+
# class TestBedrockClient(unittest.TestCase):
146+
# def setUp(self) -> None:
147+
#
148+
# # Setup the mock client
149+
# self.patcher = patch.object(BedrockAPIClient, "init_sync_client")
150+
# self.mock_init_sync_client = self.patcher.start()
151+
# self.mock_sync_client = Mock()
152+
# self.mock_init_sync_client.return_value = self.mock_sync_client
153+
# self.client = BedrockAPIClient()
154+
# self.client.sync_client = self.mock_sync_client
155+
# self.mock_response = {
156+
# "ResponseMetadata": {
157+
# "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
158+
# "HTTPStatusCode": 200,
159+
# "HTTPHeaders": {
160+
# "date": "Sat, 30 Nov 2024 14:27:44 GMT",
161+
# "content-type": "application/json",
162+
# "content-length": "273",
163+
# "connection": "keep-alive",
164+
# "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
165+
# },
166+
# "RetryAttempts": 0,
167+
# },
168+
# "output": {
169+
# "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
170+
# },
171+
# "stopReason": "end_turn",
172+
# "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
173+
# "metrics": {"latencyMs": 430},
174+
# }
175+
# self.mock_stream_response = {
176+
# "ResponseMetadata": {
177+
# "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
178+
# "HTTPStatusCode": 200,
179+
# "HTTPHeaders": {
180+
# "date": "Sun, 12 Jan 2025 15:10:00 GMT",
181+
# "content-type": "application/vnd.amazon.eventstream",
182+
# "transfer-encoding": "chunked",
183+
# "connection": "keep-alive",
184+
# "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
185+
# },
186+
# "RetryAttempts": 0,
187+
# },
188+
# "stream": iter(()),
189+
# }
190+
# self.api_kwargs = {
191+
# "messages": [{"role": "user", "content": "Hello"}],
192+
# "model": "gpt-3.5-turbo",
193+
# }
194+
#
195+
# def test_call(self) -> None:
196+
#
197+
# # Mock the converse API calls.
198+
# self.mock_sync_client.converse = Mock(return_value=self.mock_response)
199+
# self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
200+
#
201+
# # Call the call method
202+
# result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
203+
#
204+
# # Assertions
205+
# self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
206+
# self.mock_sync_client.converse_stream.assert_not_called()
207+
# self.assertEqual(result, self.mock_response)
208+
#
209+
# # test parse_chat_completion
210+
# output = self.client.parse_chat_completion(completion=self.mock_response)
211+
# self.assertTrue(isinstance(output, GeneratorOutput))
212+
# self.assertEqual(output.raw_response, "Hello, world!")
213+
# self.assertEqual(output.usage.prompt_tokens, 20)
214+
# self.assertEqual(output.usage.completion_tokens, 10)
215+
# self.assertEqual(output.usage.total_tokens, 30)
216+
#
217+
#
218+
# # def test_streaming_call(self) -> None:
219+
# # """Test that a streaming call calls the converse_stream method."""
220+
# # mock_sync_client = Mock()
221+
# #
222+
# # # Mock the converse API calls.
223+
# # mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
224+
# # mock_sync_client.converse = Mock(return_value=self.mock_response)
225+
# #
226+
# # # Set the sync client.
227+
# # self.client.sync_client = mock_sync_client
228+
# #
229+
# # # Call the call method.
230+
# # stream_kwargs = self.api_kwargs | {"stream": True}
231+
# # self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
232+
# #
233+
# # # Assert the streaming call was made.
234+
# # mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
235+
# # mock_sync_client.converse.assert_not_called()
236+
# #
237+
# # def test_call_value_error(self) -> None:
238+
# # """Test that a ValueError is raised when an invalid model_type is passed."""
239+
# # mock_sync_client = Mock()
240+
# #
241+
# # # Set the sync client
242+
# # self.client.sync_client = mock_sync_client
243+
# #
244+
# # # Test that ValueError is raised
245+
# # with self.assertRaises(ValueError):
246+
# # self.client.call(
247+
# # api_kwargs={},
248+
# # model_type=ModelType.UNDEFINED, # This should trigger ValueError
249+
# # )
250+
# #
251+
# # def test_parse_chat_completion(self) -> None:
252+
# # """Test that the parse_chat_completion does not call usage completion when
253+
# # streaming."""
254+
# # mock_track_completion_usage = Mock()
255+
# # self.client.track_completion_usage = mock_track_completion_usage
256+
# #
257+
# # self.client.chat_completion_parser = self.client.handle_stream_response
258+
# # generator_output = self.client.parse_chat_completion(self.mock_stream_response)
259+
# #
260+
# # mock_track_completion_usage.assert_not_called()
261+
# # assert isinstance(generator_output, GeneratorOutput)
262+
# #
263+
# # def test_parse_chat_completion_call_usage(self) -> None:
264+
# # """Test that the parse_chat_completion calls usage completion when streaming."""
265+
# # mock_track_completion_usage = Mock()
266+
# # self.client.track_completion_usage = mock_track_completion_usage
267+
# # generator_output = self.client.parse_chat_completion(self.mock_response)
268+
# #
269+
# # mock_track_completion_usage.assert_called_once()
270+
# # assert isinstance(generator_output, GeneratorOutput)
271+
# #
272+
# #
273+
# # if __name__ == "__main__":
274+
# # unittest.main()

0 commit comments

Comments
 (0)