11import unittest
2- from unittest .mock import patch , Mock
2+ from unittest .mock import Mock
33
44# use the openai for mocking standard data types
55
66from adalflow .core .types import ModelType , GeneratorOutput
77from adalflow .components .model_client import BedrockAPIClient
88
99
10- def getenv_side_effect (key ):
11- # This dictionary can hold more keys and values as needed
12- env_vars = {
13- "AWS_ACCESS_KEY_ID" : "fake_api_key" ,
14- "AWS_SECRET_ACCESS_KEY" : "fake_api_key" ,
15- "AWS_REGION_NAME" : "fake_api_key" ,
16- }
17- return env_vars .get (key , None ) # Returns None if key is not found
18-
19-
20- # modified from test_openai_client.py
2110class TestBedrockClient (unittest .TestCase ):
22- def setUp (self ):
11+ def setUp (self ) -> None :
2312 self .client = BedrockAPIClient ()
2413 self .mock_response = {
2514 "ResponseMetadata" : {
@@ -41,21 +30,34 @@ def setUp(self):
4130 "usage" : {"inputTokens" : 20 , "outputTokens" : 10 , "totalTokens" : 30 },
4231 "metrics" : {"latencyMs" : 430 },
4332 }
33+ self .mock_stream_response = {
34+ "ResponseMetadata" : {
35+ "RequestId" : "c76d625e-9fdb-4173-8138-debdd724fc56" ,
36+ "HTTPStatusCode" : 200 ,
37+ "HTTPHeaders" : {
38+ "date" : "Sun, 12 Jan 2025 15:10:00 GMT" ,
39+ "content-type" : "application/vnd.amazon.eventstream" ,
40+ "transfer-encoding" : "chunked" ,
41+ "connection" : "keep-alive" ,
42+ "x-amzn-requestid" : "c76d625e-9fdb-4173-8138-debdd724fc56" ,
43+ },
44+ "RetryAttempts" : 0 ,
45+ },
46+ "stream" : iter (()),
47+ }
4448
4549 self .api_kwargs = {
4650 "messages" : [{"role" : "user" , "content" : "Hello" }],
4751 "model" : "gpt-3.5-turbo" ,
4852 }
4953
50- @patch .object (BedrockAPIClient , "init_sync_client" )
51- @patch ("adalflow.components.model_client.bedrock_client.boto3" )
52- def test_call (self , MockBedrock , mock_init_sync_client ):
54+ def test_call (self ) -> None :
55+ """Test that the converse method is called correctly."""
5356 mock_sync_client = Mock ()
54- MockBedrock .return_value = mock_sync_client
55- mock_init_sync_client .return_value = mock_sync_client
5657
57- # Mock the client's api: converse
58+ # Mock the converse API calls.
5859 mock_sync_client .converse = Mock (return_value = self .mock_response )
60+ mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
5961
6062 # Set the sync client
6163 self .client .sync_client = mock_sync_client
@@ -65,6 +67,7 @@ def test_call(self, MockBedrock, mock_init_sync_client):
6567
6668 # Assertions
6769 mock_sync_client .converse .assert_called_once_with (** self .api_kwargs )
70+ mock_sync_client .converse_stream .assert_not_called ()
6871 self .assertEqual (result , self .mock_response )
6972
7073 # test parse_chat_completion
@@ -75,6 +78,60 @@ def test_call(self, MockBedrock, mock_init_sync_client):
7578 self .assertEqual (output .usage .completion_tokens , 10 )
7679 self .assertEqual (output .usage .total_tokens , 30 )
7780
81+ def test_streaming_call (self ) -> None :
82+ """Test that a streaming call calls the converse_stream method."""
83+ mock_sync_client = Mock ()
84+
85+ # Mock the converse API calls.
86+ mock_sync_client .converse_stream = Mock (return_value = self .mock_response )
87+ mock_sync_client .converse = Mock (return_value = self .mock_response )
88+
89+ # Set the sync client.
90+ self .client .sync_client = mock_sync_client
91+
92+ # Call the call method.
93+ stream_kwargs = self .api_kwargs | {"stream" : True }
94+ self .client .call (api_kwargs = stream_kwargs , model_type = ModelType .LLM )
95+
96+ # Assert the streaming call was made.
97+ mock_sync_client .converse_stream .assert_called_once_with (** stream_kwargs )
98+ mock_sync_client .converse .assert_not_called ()
99+
100+ def test_call_value_error (self ) -> None :
101+ """Test that a ValueError is raised when an invalid model_type is passed."""
102+ mock_sync_client = Mock ()
103+
104+ # Set the sync client
105+ self .client .sync_client = mock_sync_client
106+
107+ # Test that ValueError is raised
108+ with self .assertRaises (ValueError ):
109+ self .client .call (
110+ api_kwargs = {},
111+ model_type = ModelType .UNDEFINED , # This should trigger ValueError
112+ )
113+
114+ def test_parse_chat_completion (self ) -> None :
115+ """Test that the parse_chat_completion does not call usage completion when
116+ streaming."""
117+ mock_track_completion_usage = Mock ()
118+ self .client .track_completion_usage = mock_track_completion_usage
119+
120+ self .client .chat_completion_parser = self .client .handle_stream_response
121+ generator_output = self .client .parse_chat_completion (self .mock_stream_response )
122+
123+ mock_track_completion_usage .assert_not_called ()
124+ assert isinstance (generator_output , GeneratorOutput )
125+
126+ def test_parse_chat_completion_call_usage (self ) -> None :
127+ """Test that the parse_chat_completion calls usage completion when streaming."""
128+ mock_track_completion_usage = Mock ()
129+ self .client .track_completion_usage = mock_track_completion_usage
130+ generator_output = self .client .parse_chat_completion (self .mock_response )
131+
132+ mock_track_completion_usage .assert_called_once ()
133+ assert isinstance (generator_output , GeneratorOutput )
134+
78135
79136if __name__ == "__main__" :
80137 unittest .main ()
0 commit comments