Skip to content

Commit 49ac890

Browse files
author
Lloyd Hamilton
committed
test: add tests for parse_chat_completion method
1 parent 61ef29d commit 49ac890

File tree

2 files changed

+77
-21
lines changed

2 files changed

+77
-21
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType:
174174
log.debug(f"Raw chunk: {chunk}")
175175
yield chunk
176176
except Exception as e:
177-
print(f"Error in handle_stream_response: {e}") # Debug print
177+
log.debug(f"Error in handle_stream_response: {e}") # Debug print
178178
raise
179179

180180
def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
@@ -193,7 +193,6 @@ def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
193193
"""
194194
try:
195195
usage = None
196-
print(completion)
197196
data = self.chat_completion_parser(completion)
198197
if not isinstance(data, GeneratorType):
199198
# Streaming completion usage tracking is not implemented.
Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,14 @@
11
import unittest
2-
from unittest.mock import patch, Mock
2+
from unittest.mock import Mock
33

44
# use the openai for mocking standard data types
55

66
from adalflow.core.types import ModelType, GeneratorOutput
77
from adalflow.components.model_client import BedrockAPIClient
88

99

10-
def getenv_side_effect(key):
11-
# This dictionary can hold more keys and values as needed
12-
env_vars = {
13-
"AWS_ACCESS_KEY_ID": "fake_api_key",
14-
"AWS_SECRET_ACCESS_KEY": "fake_api_key",
15-
"AWS_REGION_NAME": "fake_api_key",
16-
}
17-
return env_vars.get(key, None) # Returns None if key is not found
18-
19-
20-
# modified from test_openai_client.py
2110
class TestBedrockClient(unittest.TestCase):
22-
def setUp(self):
11+
def setUp(self) -> None:
2312
self.client = BedrockAPIClient()
2413
self.mock_response = {
2514
"ResponseMetadata": {
@@ -41,21 +30,34 @@ def setUp(self):
4130
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
4231
"metrics": {"latencyMs": 430},
4332
}
33+
self.mock_stream_response = {
34+
"ResponseMetadata": {
35+
"RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
36+
"HTTPStatusCode": 200,
37+
"HTTPHeaders": {
38+
"date": "Sun, 12 Jan 2025 15:10:00 GMT",
39+
"content-type": "application/vnd.amazon.eventstream",
40+
"transfer-encoding": "chunked",
41+
"connection": "keep-alive",
42+
"x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
43+
},
44+
"RetryAttempts": 0,
45+
},
46+
"stream": iter(()),
47+
}
4448

4549
self.api_kwargs = {
4650
"messages": [{"role": "user", "content": "Hello"}],
4751
"model": "gpt-3.5-turbo",
4852
}
4953

50-
@patch.object(BedrockAPIClient, "init_sync_client")
51-
@patch("adalflow.components.model_client.bedrock_client.boto3")
52-
def test_call(self, MockBedrock, mock_init_sync_client):
54+
def test_call(self) -> None:
55+
"""Test that the converse method is called correctly."""
5356
mock_sync_client = Mock()
54-
MockBedrock.return_value = mock_sync_client
55-
mock_init_sync_client.return_value = mock_sync_client
5657

57-
# Mock the client's api: converse
58+
# Mock the converse API calls.
5859
mock_sync_client.converse = Mock(return_value=self.mock_response)
60+
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
5961

6062
# Set the sync client
6163
self.client.sync_client = mock_sync_client
@@ -65,6 +67,7 @@ def test_call(self, MockBedrock, mock_init_sync_client):
6567

6668
# Assertions
6769
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
70+
mock_sync_client.converse_stream.assert_not_called()
6871
self.assertEqual(result, self.mock_response)
6972

7073
# test parse_chat_completion
@@ -75,6 +78,60 @@ def test_call(self, MockBedrock, mock_init_sync_client):
7578
self.assertEqual(output.usage.completion_tokens, 10)
7679
self.assertEqual(output.usage.total_tokens, 30)
7780

81+
def test_streaming_call(self) -> None:
82+
"""Test that a streaming call calls the converse_stream method."""
83+
mock_sync_client = Mock()
84+
85+
# Mock the converse API calls.
86+
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
87+
mock_sync_client.converse = Mock(return_value=self.mock_response)
88+
89+
# Set the sync client.
90+
self.client.sync_client = mock_sync_client
91+
92+
# Call the call method.
93+
stream_kwargs = self.api_kwargs | {"stream": True}
94+
self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
95+
96+
# Assert the streaming call was made.
97+
mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
98+
mock_sync_client.converse.assert_not_called()
99+
100+
def test_call_value_error(self) -> None:
101+
"""Test that a ValueError is raised when an invalid model_type is passed."""
102+
mock_sync_client = Mock()
103+
104+
# Set the sync client
105+
self.client.sync_client = mock_sync_client
106+
107+
# Test that ValueError is raised
108+
with self.assertRaises(ValueError):
109+
self.client.call(
110+
api_kwargs={},
111+
model_type=ModelType.UNDEFINED, # This should trigger ValueError
112+
)
113+
114+
def test_parse_chat_completion(self) -> None:
115+
"""Test that the parse_chat_completion does not call usage completion when
116+
streaming."""
117+
mock_track_completion_usage = Mock()
118+
self.client.track_completion_usage = mock_track_completion_usage
119+
120+
self.client.chat_completion_parser = self.client.handle_stream_response
121+
generator_output = self.client.parse_chat_completion(self.mock_stream_response)
122+
123+
mock_track_completion_usage.assert_not_called()
124+
assert isinstance(generator_output, GeneratorOutput)
125+
126+
def test_parse_chat_completion_call_usage(self) -> None:
127+
"""Test that the parse_chat_completion calls usage completion when streaming."""
128+
mock_track_completion_usage = Mock()
129+
self.client.track_completion_usage = mock_track_completion_usage
130+
generator_output = self.client.parse_chat_completion(self.mock_response)
131+
132+
mock_track_completion_usage.assert_called_once()
133+
assert isinstance(generator_output, GeneratorOutput)
134+
78135

79136
if __name__ == "__main__":
80137
unittest.main()

0 commit comments

Comments
 (0)