|
1 | 1 | import unittest |
2 | 2 | from unittest.mock import Mock, patch |
3 | | - |
4 | | -# use the openai for mocking standard data types |
5 | | - |
6 | 3 | from adalflow.core.types import ModelType, GeneratorOutput |
7 | 4 | from adalflow.components.model_client import BedrockAPIClient |
8 | 5 |
|
@@ -140,135 +137,3 @@ def test_parse_streaming_chat_completion(self) -> None: |
140 | 137 |
|
141 | 138 | if __name__ == "__main__": |
142 | 139 | unittest.main() |
143 | | - |
144 | | -# |
145 | | -# class TestBedrockClient(unittest.TestCase): |
146 | | -# def setUp(self) -> None: |
147 | | -# |
148 | | -# # Setup the mock client |
149 | | -# self.patcher = patch.object(BedrockAPIClient, "init_sync_client") |
150 | | -# self.mock_init_sync_client = self.patcher.start() |
151 | | -# self.mock_sync_client = Mock() |
152 | | -# self.mock_init_sync_client.return_value = self.mock_sync_client |
153 | | -# self.client = BedrockAPIClient() |
154 | | -# self.client.sync_client = self.mock_sync_client |
155 | | -# self.mock_response = { |
156 | | -# "ResponseMetadata": { |
157 | | -# "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", |
158 | | -# "HTTPStatusCode": 200, |
159 | | -# "HTTPHeaders": { |
160 | | -# "date": "Sat, 30 Nov 2024 14:27:44 GMT", |
161 | | -# "content-type": "application/json", |
162 | | -# "content-length": "273", |
163 | | -# "connection": "keep-alive", |
164 | | -# "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569", |
165 | | -# }, |
166 | | -# "RetryAttempts": 0, |
167 | | -# }, |
168 | | -# "output": { |
169 | | -# "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]} |
170 | | -# }, |
171 | | -# "stopReason": "end_turn", |
172 | | -# "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, |
173 | | -# "metrics": {"latencyMs": 430}, |
174 | | -# } |
175 | | -# self.mock_stream_response = { |
176 | | -# "ResponseMetadata": { |
177 | | -# "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56", |
178 | | -# "HTTPStatusCode": 200, |
179 | | -# "HTTPHeaders": { |
180 | | -# "date": "Sun, 12 Jan 2025 15:10:00 GMT", |
181 | | -# "content-type": "application/vnd.amazon.eventstream", |
182 | | -# "transfer-encoding": "chunked", |
183 | | -# "connection": "keep-alive", |
184 | | -# "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56", |
185 | | -# }, |
186 | | -# "RetryAttempts": 0, |
187 | | -# }, |
188 | | -# "stream": iter(()), |
189 | | -# } |
190 | | -# self.api_kwargs = { |
191 | | -# "messages": [{"role": "user", "content": "Hello"}], |
192 | | -# "model": "gpt-3.5-turbo", |
193 | | -# } |
194 | | -# |
195 | | -# def test_call(self) -> None: |
196 | | -# |
197 | | -# # Mock the converse API calls. |
198 | | -# self.mock_sync_client.converse = Mock(return_value=self.mock_response) |
199 | | -# self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response) |
200 | | -# |
201 | | -# # Call the call method |
202 | | -# result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) |
203 | | -# |
204 | | -# # Assertions |
205 | | -# self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) |
206 | | -# self.mock_sync_client.converse_stream.assert_not_called() |
207 | | -# self.assertEqual(result, self.mock_response) |
208 | | -# |
209 | | -# # test parse_chat_completion |
210 | | -# output = self.client.parse_chat_completion(completion=self.mock_response) |
211 | | -# self.assertTrue(isinstance(output, GeneratorOutput)) |
212 | | -# self.assertEqual(output.raw_response, "Hello, world!") |
213 | | -# self.assertEqual(output.usage.prompt_tokens, 20) |
214 | | -# self.assertEqual(output.usage.completion_tokens, 10) |
215 | | -# self.assertEqual(output.usage.total_tokens, 30) |
216 | | -# |
217 | | -# |
218 | | -# # def test_streaming_call(self) -> None: |
219 | | -# # """Test that a streaming call calls the converse_stream method.""" |
220 | | -# # mock_sync_client = Mock() |
221 | | -# # |
222 | | -# # # Mock the converse API calls. |
223 | | -# # mock_sync_client.converse_stream = Mock(return_value=self.mock_response) |
224 | | -# # mock_sync_client.converse = Mock(return_value=self.mock_response) |
225 | | -# # |
226 | | -# # # Set the sync client. |
227 | | -# # self.client.sync_client = mock_sync_client |
228 | | -# # |
229 | | -# # # Call the call method. |
230 | | -# # stream_kwargs = self.api_kwargs | {"stream": True} |
231 | | -# # self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM) |
232 | | -# # |
233 | | -# # # Assert the streaming call was made. |
234 | | -# # mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs) |
235 | | -# # mock_sync_client.converse.assert_not_called() |
236 | | -# # |
237 | | -# # def test_call_value_error(self) -> None: |
238 | | -# # """Test that a ValueError is raised when an invalid model_type is passed.""" |
239 | | -# # mock_sync_client = Mock() |
240 | | -# # |
241 | | -# # # Set the sync client |
242 | | -# # self.client.sync_client = mock_sync_client |
243 | | -# # |
244 | | -# # # Test that ValueError is raised |
245 | | -# # with self.assertRaises(ValueError): |
246 | | -# # self.client.call( |
247 | | -# # api_kwargs={}, |
248 | | -# # model_type=ModelType.UNDEFINED, # This should trigger ValueError |
249 | | -# # ) |
250 | | -# # |
251 | | -# # def test_parse_chat_completion(self) -> None: |
252 | | -# # """Test that the parse_chat_completion does not call usage completion when |
253 | | -# # streaming.""" |
254 | | -# # mock_track_completion_usage = Mock() |
255 | | -# # self.client.track_completion_usage = mock_track_completion_usage |
256 | | -# # |
257 | | -# # self.client.chat_completion_parser = self.client.handle_stream_response |
258 | | -# # generator_output = self.client.parse_chat_completion(self.mock_stream_response) |
259 | | -# # |
260 | | -# # mock_track_completion_usage.assert_not_called() |
261 | | -# # assert isinstance(generator_output, GeneratorOutput) |
262 | | -# # |
263 | | -# # def test_parse_chat_completion_call_usage(self) -> None: |
264 | | -# # """Test that the parse_chat_completion calls usage completion when streaming.""" |
265 | | -# # mock_track_completion_usage = Mock() |
266 | | -# # self.client.track_completion_usage = mock_track_completion_usage |
267 | | -# # generator_output = self.client.parse_chat_completion(self.mock_response) |
268 | | -# # |
269 | | -# # mock_track_completion_usage.assert_called_once() |
270 | | -# # assert isinstance(generator_output, GeneratorOutput) |
271 | | -# # |
272 | | -# # |
273 | | -# # if __name__ == "__main__": |
274 | | -# # unittest.main() |
0 commit comments