11import unittest
2- from unittest .mock import Mock
2+ from unittest .mock import Mock , patch
33
44# use the openai for mocking standard data types
55
99
1010class TestBedrockClient (unittest .TestCase ):
1111 def setUp (self ) -> None :
12+ """Set up mocks and test data.
13+
14+ Mocks the boto3 session and the init_sync_client method. Mocks will create a
15+ mock bedrock client and mock responses that can be reused across tests.
16+ """
17+ self .session_patcher = patch (
18+ "adalflow.components.model_client.bedrock_client.boto3.Session"
19+ )
20+ self .mock_session = self .session_patcher .start ()
21+ self .mock_boto3_client = Mock ()
22+ self .mock_session .return_value .client .return_value = self .mock_boto3_client
23+ self .init_sync_patcher = patch .object (BedrockAPIClient , "init_sync_client" )
24+ self .mock_init_sync_client = self .init_sync_patcher .start ()
25+ self .mock_sync_client = Mock ()
26+ self .mock_init_sync_client .return_value = self .mock_sync_client
27+ self .mock_sync_client .converse = Mock ()
28+ self .mock_sync_client .converse_stream = Mock ()
1229 self .client = BedrockAPIClient ()
30+ self .client .sync_client = self .mock_sync_client
31+
1332 self .mock_response = {
1433 "ResponseMetadata" : {
1534 "RequestId" : "43aec10a-9780-4bd5-abcc-857d12460569" ,
@@ -45,73 +64,68 @@ def setUp(self) -> None:
4564 },
4665 "stream" : iter (()),
4766 }
48-
4967 self .api_kwargs = {
5068 "messages" : [{"role" : "user" , "content" : "Hello" }],
5169 "model" : "gpt-3.5-turbo" ,
5270 }
5371
54- def test_call (self ) -> None :
55- """Test that the converse method is called correctly."""
56- mock_sync_client = Mock ()
57-
58- # Mock the converse API calls.
59- mock_sync_client .converse = Mock (return_value = self .mock_response )
60- mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
72+ def tearDown (self ) -> None :
73+ """Stop the patchers."""
74+ self .init_sync_patcher .stop ()
6175
62- # Set the sync client
63- self .client .sync_client = mock_sync_client
76+ def test_call (self ) -> None :
77+ """Tests that the call method calls the converse method correctly."""
78+ self .mock_sync_client .converse = Mock (return_value = self .mock_response )
79+ self .mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
6480
65- # Call the call method
6681 result = self .client .call (api_kwargs = self .api_kwargs , model_type = ModelType .LLM )
6782
68- # Assertions
69- mock_sync_client .converse .assert_called_once_with (** self .api_kwargs )
70- mock_sync_client .converse_stream .assert_not_called ()
83+ # Assertions: converse is called once and stream is not called
84+ self . mock_sync_client .converse .assert_called_once_with (** self .api_kwargs )
85+ self . mock_sync_client .converse_stream .assert_not_called ()
7186 self .assertEqual (result , self .mock_response )
7287
73- # test parse_chat_completion
88+ def test_parse_chat_completion (self ) -> None :
89+ """Tests that the parse_chat_completion method returns expected object."""
7490 output = self .client .parse_chat_completion (completion = self .mock_response )
7591 self .assertTrue (isinstance (output , GeneratorOutput ))
7692 self .assertEqual (output .raw_response , "Hello, world!" )
7793 self .assertEqual (output .usage .prompt_tokens , 20 )
7894 self .assertEqual (output .usage .completion_tokens , 10 )
7995 self .assertEqual (output .usage .total_tokens , 30 )
8096
81- def test_streaming_call (self ) -> None :
82- """Test that a streaming call calls the converse_stream method."""
83- mock_sync_client = Mock ()
97+ def test_parse_chat_completion_call_usage (self ) -> None :
98+ """Test that the parse_chat_completion calls usage completion when not
99+ streaming."""
100+ mock_track_completion_usage = Mock ()
101+ self .client .track_completion_usage = mock_track_completion_usage
102+ generator_output = self .client .parse_chat_completion (self .mock_response )
84103
85- # Mock the converse API calls.
86- mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
87- mock_sync_client .converse = Mock (return_value = self .mock_response )
104+ mock_track_completion_usage .assert_called_once ()
105+ assert isinstance (generator_output , GeneratorOutput )
88106
89- # Set the sync client.
90- self .client .sync_client = mock_sync_client
107+ def test_streaming_call (self ) -> None :
108+ """Test that a streaming call calls the converse_stream method."""
109+ self .mock_sync_client .converse = Mock (return_value = self .mock_response )
110+ self .mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
91111
92112 # Call the call method.
93113 stream_kwargs = self .api_kwargs | {"stream" : True }
94114 self .client .call (api_kwargs = stream_kwargs , model_type = ModelType .LLM )
95115
96- # Assert the streaming call was made.
97- mock_sync_client .converse_stream .assert_called_once_with (** stream_kwargs )
98- mock_sync_client .converse .assert_not_called ()
116+ # Assertions: Streaming method is called
117+ self . mock_sync_client .converse_stream .assert_called_once_with (** stream_kwargs )
118+ self . mock_sync_client .converse .assert_not_called ()
99119
100120 def test_call_value_error (self ) -> None :
101121 """Test that a ValueError is raised when an invalid model_type is passed."""
102- mock_sync_client = Mock ()
103-
104- # Set the sync client
105- self .client .sync_client = mock_sync_client
106-
107- # Test that ValueError is raised
108122 with self .assertRaises (ValueError ):
109123 self .client .call (
110124 api_kwargs = {},
111125 model_type = ModelType .UNDEFINED , # This should trigger ValueError
112126 )
113127
114- def test_parse_chat_completion (self ) -> None :
128+ def test_parse_streaming_chat_completion (self ) -> None :
115129 """Test that the parse_chat_completion does not call usage completion when
116130 streaming."""
117131 mock_track_completion_usage = Mock ()
@@ -123,15 +137,138 @@ def test_parse_chat_completion(self) -> None:
123137 mock_track_completion_usage .assert_not_called ()
124138 assert isinstance (generator_output , GeneratorOutput )
125139
126- def test_parse_chat_completion_call_usage (self ) -> None :
127- """Test that the parse_chat_completion calls usage completion when streaming."""
128- mock_track_completion_usage = Mock ()
129- self .client .track_completion_usage = mock_track_completion_usage
130- generator_output = self .client .parse_chat_completion (self .mock_response )
131-
132- mock_track_completion_usage .assert_called_once ()
133- assert isinstance (generator_output , GeneratorOutput )
134-
135140
136141if __name__ == "__main__" :
137142 unittest .main ()
143+
144+ #
145+ # class TestBedrockClient(unittest.TestCase):
146+ # def setUp(self) -> None:
147+ #
148+ # # Setup the mock client
149+ # self.patcher = patch.object(BedrockAPIClient, "init_sync_client")
150+ # self.mock_init_sync_client = self.patcher.start()
151+ # self.mock_sync_client = Mock()
152+ # self.mock_init_sync_client.return_value = self.mock_sync_client
153+ # self.client = BedrockAPIClient()
154+ # self.client.sync_client = self.mock_sync_client
155+ # self.mock_response = {
156+ # "ResponseMetadata": {
157+ # "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
158+ # "HTTPStatusCode": 200,
159+ # "HTTPHeaders": {
160+ # "date": "Sat, 30 Nov 2024 14:27:44 GMT",
161+ # "content-type": "application/json",
162+ # "content-length": "273",
163+ # "connection": "keep-alive",
164+ # "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
165+ # },
166+ # "RetryAttempts": 0,
167+ # },
168+ # "output": {
169+ # "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
170+ # },
171+ # "stopReason": "end_turn",
172+ # "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
173+ # "metrics": {"latencyMs": 430},
174+ # }
175+ # self.mock_stream_response = {
176+ # "ResponseMetadata": {
177+ # "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
178+ # "HTTPStatusCode": 200,
179+ # "HTTPHeaders": {
180+ # "date": "Sun, 12 Jan 2025 15:10:00 GMT",
181+ # "content-type": "application/vnd.amazon.eventstream",
182+ # "transfer-encoding": "chunked",
183+ # "connection": "keep-alive",
184+ # "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
185+ # },
186+ # "RetryAttempts": 0,
187+ # },
188+ # "stream": iter(()),
189+ # }
190+ # self.api_kwargs = {
191+ # "messages": [{"role": "user", "content": "Hello"}],
192+ # "model": "gpt-3.5-turbo",
193+ # }
194+ #
195+ # def test_call(self) -> None:
196+ #
197+ # # Mock the converse API calls.
198+ # self.mock_sync_client.converse = Mock(return_value=self.mock_response)
199+ # self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
200+ #
201+ # # Call the call method
202+ # result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
203+ #
204+ # # Assertions
205+ # self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
206+ # self.mock_sync_client.converse_stream.assert_not_called()
207+ # self.assertEqual(result, self.mock_response)
208+ #
209+ # # test parse_chat_completion
210+ # output = self.client.parse_chat_completion(completion=self.mock_response)
211+ # self.assertTrue(isinstance(output, GeneratorOutput))
212+ # self.assertEqual(output.raw_response, "Hello, world!")
213+ # self.assertEqual(output.usage.prompt_tokens, 20)
214+ # self.assertEqual(output.usage.completion_tokens, 10)
215+ # self.assertEqual(output.usage.total_tokens, 30)
216+ #
217+ #
218+ # # def test_streaming_call(self) -> None:
219+ # # """Test that a streaming call calls the converse_stream method."""
220+ # # mock_sync_client = Mock()
221+ # #
222+ # # # Mock the converse API calls.
223+ # # mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
224+ # # mock_sync_client.converse = Mock(return_value=self.mock_response)
225+ # #
226+ # # # Set the sync client.
227+ # # self.client.sync_client = mock_sync_client
228+ # #
229+ # # # Call the call method.
230+ # # stream_kwargs = self.api_kwargs | {"stream": True}
231+ # # self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
232+ # #
233+ # # # Assert the streaming call was made.
234+ # # mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
235+ # # mock_sync_client.converse.assert_not_called()
236+ # #
237+ # # def test_call_value_error(self) -> None:
238+ # # """Test that a ValueError is raised when an invalid model_type is passed."""
239+ # # mock_sync_client = Mock()
240+ # #
241+ # # # Set the sync client
242+ # # self.client.sync_client = mock_sync_client
243+ # #
244+ # # # Test that ValueError is raised
245+ # # with self.assertRaises(ValueError):
246+ # # self.client.call(
247+ # # api_kwargs={},
248+ # # model_type=ModelType.UNDEFINED, # This should trigger ValueError
249+ # # )
250+ # #
251+ # # def test_parse_chat_completion(self) -> None:
252+ # # """Test that the parse_chat_completion does not call usage completion when
253+ # # streaming."""
254+ # # mock_track_completion_usage = Mock()
255+ # # self.client.track_completion_usage = mock_track_completion_usage
256+ # #
257+ # # self.client.chat_completion_parser = self.client.handle_stream_response
258+ # # generator_output = self.client.parse_chat_completion(self.mock_stream_response)
259+ # #
260+ # # mock_track_completion_usage.assert_not_called()
261+ # # assert isinstance(generator_output, GeneratorOutput)
262+ # #
263+ # # def test_parse_chat_completion_call_usage(self) -> None:
264+ # # """Test that the parse_chat_completion calls usage completion when streaming."""
265+ # # mock_track_completion_usage = Mock()
266+ # # self.client.track_completion_usage = mock_track_completion_usage
267+ # # generator_output = self.client.parse_chat_completion(self.mock_response)
268+ # #
269+ # # mock_track_completion_usage.assert_called_once()
270+ # # assert isinstance(generator_output, GeneratorOutput)
271+ # #
272+ # #
273+ # # if __name__ == "__main__":
274+ # # unittest.main()
0 commit comments