|
| 1 | +import unittest |
| 2 | +from unittest.mock import Mock, patch |
| 3 | +from adalflow.core.types import ModelType, GeneratorOutput |
| 4 | +from adalflow.components.model_client import BedrockAPIClient |
| 5 | + |
| 6 | + |
| 7 | +class TestBedrockClient(unittest.TestCase): |
| 8 | + def setUp(self) -> None: |
| 9 | + """Set up mocks and test data. |
| 10 | +
|
| 11 | + Mocks the boto3 session and the init_sync_client method. Mocks will create a |
| 12 | + mock bedrock client and mock responses that can be reused across tests. |
| 13 | + """ |
| 14 | + self.session_patcher = patch( |
| 15 | + "adalflow.components.model_client.bedrock_client.boto3.Session" |
| 16 | + ) |
| 17 | + self.mock_session = self.session_patcher.start() |
| 18 | + self.mock_boto3_client = Mock() |
| 19 | + self.mock_session.return_value.client.return_value = self.mock_boto3_client |
| 20 | + self.init_sync_patcher = patch.object(BedrockAPIClient, "init_sync_client") |
| 21 | + self.mock_init_sync_client = self.init_sync_patcher.start() |
| 22 | + self.mock_sync_client = Mock() |
| 23 | + self.mock_init_sync_client.return_value = self.mock_sync_client |
| 24 | + self.mock_sync_client.converse = Mock() |
| 25 | + self.mock_sync_client.converse_stream = Mock() |
| 26 | + self.client = BedrockAPIClient() |
| 27 | + self.client.sync_client = self.mock_sync_client |
| 28 | + |
| 29 | + self.mock_response = { |
| 30 | + "ResponseMetadata": { |
| 31 | + "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", |
| 32 | + "HTTPStatusCode": 200, |
| 33 | + "HTTPHeaders": { |
| 34 | + "date": "Sat, 30 Nov 2024 14:27:44 GMT", |
| 35 | + "content-type": "application/json", |
| 36 | + "content-length": "273", |
| 37 | + "connection": "keep-alive", |
| 38 | + "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569", |
| 39 | + }, |
| 40 | + "RetryAttempts": 0, |
| 41 | + }, |
| 42 | + "output": { |
| 43 | + "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]} |
| 44 | + }, |
| 45 | + "stopReason": "end_turn", |
| 46 | + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, |
| 47 | + "metrics": {"latencyMs": 430}, |
| 48 | + } |
| 49 | + self.mock_stream_response = { |
| 50 | + "ResponseMetadata": { |
| 51 | + "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56", |
| 52 | + "HTTPStatusCode": 200, |
| 53 | + "HTTPHeaders": { |
| 54 | + "date": "Sun, 12 Jan 2025 15:10:00 GMT", |
| 55 | + "content-type": "application/vnd.amazon.eventstream", |
| 56 | + "transfer-encoding": "chunked", |
| 57 | + "connection": "keep-alive", |
| 58 | + "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56", |
| 59 | + }, |
| 60 | + "RetryAttempts": 0, |
| 61 | + }, |
| 62 | + "stream": iter(()), |
| 63 | + } |
| 64 | + self.api_kwargs = { |
| 65 | + "messages": [{"role": "user", "content": "Hello"}], |
| 66 | + "model": "gpt-3.5-turbo", |
| 67 | + } |
| 68 | + |
| 69 | + def tearDown(self) -> None: |
| 70 | + """Stop the patchers.""" |
| 71 | + self.init_sync_patcher.stop() |
| 72 | + |
| 73 | + def test_call(self) -> None: |
| 74 | + """Tests that the call method calls the converse method correctly.""" |
| 75 | + self.mock_sync_client.converse = Mock(return_value=self.mock_response) |
| 76 | + self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response) |
| 77 | + |
| 78 | + result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) |
| 79 | + |
| 80 | + # Assertions: converse is called once and stream is not called |
| 81 | + self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) |
| 82 | + self.mock_sync_client.converse_stream.assert_not_called() |
| 83 | + self.assertEqual(result, self.mock_response) |
| 84 | + |
| 85 | + def test_parse_chat_completion(self) -> None: |
| 86 | + """Tests that the parse_chat_completion method returns expected object.""" |
| 87 | + output = self.client.parse_chat_completion(completion=self.mock_response) |
| 88 | + self.assertTrue(isinstance(output, GeneratorOutput)) |
| 89 | + self.assertEqual(output.raw_response, "Hello, world!") |
| 90 | + self.assertEqual(output.usage.prompt_tokens, 20) |
| 91 | + self.assertEqual(output.usage.completion_tokens, 10) |
| 92 | + self.assertEqual(output.usage.total_tokens, 30) |
| 93 | + |
| 94 | + def test_parse_chat_completion_call_usage(self) -> None: |
| 95 | + """Test that the parse_chat_completion calls usage completion when not |
| 96 | + streaming.""" |
| 97 | + mock_track_completion_usage = Mock() |
| 98 | + self.client.track_completion_usage = mock_track_completion_usage |
| 99 | + generator_output = self.client.parse_chat_completion(self.mock_response) |
| 100 | + |
| 101 | + mock_track_completion_usage.assert_called_once() |
| 102 | + assert isinstance(generator_output, GeneratorOutput) |
| 103 | + |
| 104 | + def test_streaming_call(self) -> None: |
| 105 | + """Test that a streaming call calls the converse_stream method.""" |
| 106 | + self.mock_sync_client.converse = Mock(return_value=self.mock_response) |
| 107 | + self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response) |
| 108 | + |
| 109 | + # Call the call method. |
| 110 | + stream_kwargs = self.api_kwargs | {"stream": True} |
| 111 | + self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM) |
| 112 | + |
| 113 | + # Assertions: Streaming method is called |
| 114 | + self.mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs) |
| 115 | + self.mock_sync_client.converse.assert_not_called() |
| 116 | + |
| 117 | + def test_call_value_error(self) -> None: |
| 118 | + """Test that a ValueError is raised when an invalid model_type is passed.""" |
| 119 | + with self.assertRaises(ValueError): |
| 120 | + self.client.call( |
| 121 | + api_kwargs={}, |
| 122 | + model_type=ModelType.UNDEFINED, # This should trigger ValueError |
| 123 | + ) |
| 124 | + |
| 125 | + def test_parse_streaming_chat_completion(self) -> None: |
| 126 | + """Test that the parse_chat_completion does not call usage completion when |
| 127 | + streaming.""" |
| 128 | + mock_track_completion_usage = Mock() |
| 129 | + self.client.track_completion_usage = mock_track_completion_usage |
| 130 | + |
| 131 | + self.client.chat_completion_parser = self.client.handle_stream_response |
| 132 | + generator_output = self.client.parse_chat_completion(self.mock_stream_response) |
| 133 | + |
| 134 | + mock_track_completion_usage.assert_not_called() |
| 135 | + assert isinstance(generator_output, GeneratorOutput) |
| 136 | + |
| 137 | + |
| 138 | +if __name__ == "__main__": |
| 139 | + unittest.main() |
0 commit comments