Skip to content

Commit 8e38c2e

Browse files
unittestcase fixing test_greeting
1 parent 37906ee commit 8e38c2e

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

src/tests/api/plugins/test_chat_with_data_plugin.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,19 @@ def chat_plugin(mock_config):
2525

2626

2727
class TestChatWithDataPlugin:
28-
@patch("plugins.chat_with_data_plugin.get_bearer_token_provider")
28+
@patch("helpers.azure_openai_helper.Config")
29+
@patch("helpers.azure_openai_helper.get_bearer_token_provider")
2930
@patch("helpers.azure_openai_helper.openai.AzureOpenAI")
3031
@pytest.mark.asyncio
31-
async def test_greeting(self, mock_azure_openai, mock_token_provider, chat_plugin):
32+
async def test_greeting(self, mock_azure_openai, mock_token_provider, mock_config, chat_plugin):
3233
# Setup mock token provider
3334
mock_token_provider.return_value = lambda: "fake_token"
3435

3536
# Setup mock client and completion response
37+
mock_config_instance = MagicMock()
38+
mock_config_instance.azure_openai_endpoint = "https://test-openai.azure.com/"
39+
mock_config_instance.azure_openai_api_version = "2024-02-15-preview"
40+
mock_config.return_value = mock_config_instance
3641
mock_client = MagicMock()
3742
mock_completion = MagicMock()
3843
mock_completion.choices = [MagicMock()]
@@ -106,12 +111,17 @@ async def test_greeting_exception(self, mock_azure_openai, chat_plugin):
106111
assert result == "Details could not be retrieved. Please try again later."
107112

108113
@pytest.mark.asyncio
109-
@patch("plugins.chat_with_data_plugin.get_bearer_token_provider")
114+
@patch("helpers.azure_openai_helper.Config")
115+
@patch("helpers.azure_openai_helper.get_bearer_token_provider")
110116
@patch("plugins.chat_with_data_plugin.execute_sql_query")
111117
@patch("helpers.azure_openai_helper.openai.AzureOpenAI")
112-
async def test_get_SQL_Response(self, mock_azure_openai, mock_execute_sql, mock_token_provider, chat_plugin):
118+
async def test_get_SQL_Response(self, mock_azure_openai, mock_execute_sql, mock_token_provider, mock_config, chat_plugin):
113119

114120
# Setup mocks
121+
mock_config_instance = MagicMock()
122+
mock_config_instance.azure_openai_endpoint = "https://test-openai.azure.com/"
123+
mock_config_instance.azure_openai_api_version = "2024-02-15-preview"
124+
mock_config.return_value = mock_config_instance
115125
mock_token_provider.return_value = lambda: "fake_token"
116126
mock_client = MagicMock()
117127
mock_azure_openai.return_value = mock_client
@@ -136,6 +146,7 @@ async def test_get_SQL_Response(self, mock_azure_openai, mock_execute_sql, mock_
136146
mock_execute_sql.assert_called_once_with("SELECT * FROM km_processed_data")
137147

138148
@pytest.mark.asyncio
149+
@patch("helpers.azure_openai_helper.Config")
139150
@patch("plugins.chat_with_data_plugin.execute_sql_query")
140151
@patch("plugins.chat_with_data_plugin.AIProjectClient")
141152
@patch("plugins.chat_with_data_plugin.DefaultAzureCredential")
@@ -144,8 +155,12 @@ async def test_get_SQL_Response_with_ai_project_client(self, mock_azure_credenti
144155
chat_plugin.use_ai_project_client = True
145156

146157
# Setup mocks
158+
mock_config_instance = MagicMock()
159+
mock_config_instance.ai_project_endpoint = "https://test-openai.azure.com/"
160+
mock_config.return_value = mock_config_instance
147161
mock_project = MagicMock()
148-
mock_ai_project_client.from_connection_string.return_value = mock_project
162+
mock_ai_project_client = MagicMock()
163+
mock_ai_project_client.return_value = mock_project
149164
mock_client = MagicMock()
150165
mock_project.inference.get_chat_completions_client.return_value = mock_client
151166
mock_completion = MagicMock()
@@ -160,8 +175,10 @@ async def test_get_SQL_Response_with_ai_project_client(self, mock_azure_credenti
160175

161176
# Assertions
162177
assert result == "Query results data with AI Project Client"
163-
mock_ai_project_client.from_connection_string.assert_called_once()
164-
mock_client.complete.assert_called_once()
178+
mock_client.assert_called_once_with(
179+
endpoint="https://test-openai.azure.com/",
180+
credential=mock_azure_credential,
181+
)
165182
mock_execute_sql.assert_called_once_with("\nSELECT * FROM km_processed_data\n")
166183

167184
@pytest.mark.asyncio

0 commit comments

Comments
 (0)