Skip to content

Commit 9e61c8c

Browse files
unittestcase changes for sqlagent
1 parent 53f8e7f commit 9e61c8c

File tree

5 files changed

+260
-199
lines changed

5 files changed

+260
-199
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pytest
2+
import asyncio
3+
from unittest.mock import AsyncMock, patch, MagicMock
4+
5+
from common.config.config import Config
6+
from agents.agent_factory_base import BaseAgentFactory
7+
8+
9+
class MockAgentFactory(BaseAgentFactory):
10+
"""Concrete test class extending BaseAgentFactory for unit testing."""
11+
_created = False
12+
_deleted = False
13+
14+
@classmethod
15+
async def create_agent(cls, config: Config):
16+
cls._created = True
17+
return {"agent": "mock-agent"}
18+
19+
@classmethod
20+
async def _delete_agent_instance(cls, agent: object):
21+
cls._deleted = True
22+
23+
24+
@pytest.fixture(autouse=True)
25+
def reset_factory_state():
26+
MockAgentFactory._agent = None
27+
MockAgentFactory._created = False
28+
MockAgentFactory._deleted = False
29+
yield
30+
MockAgentFactory._agent = None
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_get_agent_creates_singleton():
35+
# Agent should be None initially
36+
assert MockAgentFactory._agent is None
37+
38+
result1 = await MockAgentFactory.get_agent()
39+
result2 = await MockAgentFactory.get_agent()
40+
41+
# Should be the same object
42+
assert result1 is result2
43+
assert MockAgentFactory._created is True
44+
assert MockAgentFactory._agent == {"agent": "mock-agent"}
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_delete_agent_removes_singleton():
49+
# Set initial agent
50+
await MockAgentFactory.get_agent()
51+
assert MockAgentFactory._agent is not None
52+
53+
await MockAgentFactory.delete_agent()
54+
55+
assert MockAgentFactory._agent is None
56+
assert MockAgentFactory._deleted is True
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_delete_agent_does_nothing_if_none():
61+
# Agent is None
62+
await MockAgentFactory.delete_agent()
63+
64+
assert MockAgentFactory._agent is None
65+
assert MockAgentFactory._deleted is False
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_thread_safety_of_get_agent(monkeypatch):
70+
# Patch create_agent to delay and track calls
71+
call_count = 0
72+
73+
async def slow_create_agent(config):
74+
nonlocal call_count
75+
call_count += 1
76+
await asyncio.sleep(0.1)
77+
return {"agent": "thread-safe"}
78+
79+
monkeypatch.setattr(MockAgentFactory, "create_agent", slow_create_agent)
80+
81+
# Run get_agent concurrently
82+
results = await asyncio.gather(
83+
MockAgentFactory.get_agent(),
84+
MockAgentFactory.get_agent(),
85+
MockAgentFactory.get_agent()
86+
)
87+
88+
# All should return the same instance
89+
assert all(result == {"agent": "thread-safe"} for result in results)
90+
assert call_count == 1 # Only one creation
Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import pytest
2-
import asyncio
3-
from unittest.mock import patch, MagicMock, AsyncMock
4-
2+
from unittest.mock import patch, MagicMock
53
from agents.search_agent_factory import SearchAgentFactory
64

75

@@ -13,48 +11,49 @@ def reset_search_agent_factory():
1311

1412

1513
@pytest.mark.asyncio
16-
@patch("agents.search_agent_factory.Config", autospec=True)
1714
@patch("agents.search_agent_factory.DefaultAzureCredential", autospec=True)
1815
@patch("agents.search_agent_factory.AIProjectClient", autospec=True)
1916
@patch("agents.search_agent_factory.AzureAISearchTool", autospec=True)
20-
async def test_get_agent_creates_new_instance(
21-
mock_search_tool,
22-
mock_project_client_class,
23-
mock_credential,
24-
mock_config_class
17+
async def test_create_agent_creates_new_instance(
18+
mock_search_tool_cls,
19+
mock_project_client_cls,
20+
mock_credential_cls
2521
):
2622
# Mock config
2723
mock_config = MagicMock()
2824
mock_config.ai_project_endpoint = "https://fake-endpoint"
2925
mock_config.azure_ai_search_connection_name = "fake-connection"
3026
mock_config.azure_ai_search_index = "fake-index"
3127
mock_config.azure_openai_deployment_model = "fake-model"
32-
mock_config_class.return_value = mock_config
28+
mock_config.solution_name = "test-solution"
29+
mock_config.ai_project_api_version = "2025-05-01"
3330

3431
# Mock project client
3532
mock_project_client = MagicMock()
36-
mock_project_client_class.return_value = mock_project_client
33+
mock_project_client_cls.return_value = mock_project_client
3734

35+
# Mock index response
3836
mock_index = MagicMock()
3937
mock_index.name = "index-name"
4038
mock_index.version = "1"
4139
mock_project_client.indexes.create_or_update.return_value = mock_index
4240

4341
# Mock search tool
44-
mock_tool = MagicMock()
45-
mock_tool.definitions = ["tool-def"]
46-
mock_tool.resources = ["tool-res"]
47-
mock_search_tool.return_value = mock_tool
42+
mock_search_tool_instance = MagicMock()
43+
mock_search_tool_instance.definitions = ["tool-def"]
44+
mock_search_tool_instance.resources = ["tool-res"]
45+
mock_search_tool_cls.return_value = mock_search_tool_instance
4846

4947
# Mock agent
5048
mock_agent = MagicMock()
5149
mock_project_client.agents.create_agent.return_value = mock_agent
5250

53-
# Run the factory
54-
result = await SearchAgentFactory.get_agent()
51+
# Run the factory directly
52+
result = await SearchAgentFactory.create_agent(mock_config)
5553

5654
assert result["agent"] == mock_agent
5755
assert result["client"] == mock_project_client
56+
5857
mock_project_client.indexes.create_or_update.assert_called_once_with(
5958
name="project-index-fake-connection-fake-index",
6059
version="1",
@@ -74,23 +73,20 @@ async def test_get_agent_creates_new_instance(
7473

7574
@pytest.mark.asyncio
7675
async def test_get_agent_returns_existing_instance():
76+
# Setup: Already initialized
7777
SearchAgentFactory._agent = {"agent": MagicMock(), "client": MagicMock()}
7878
result = await SearchAgentFactory.get_agent()
7979
assert result == SearchAgentFactory._agent
8080

8181

8282
@pytest.mark.asyncio
8383
async def test_delete_agent_removes_agent():
84+
# Setup
8485
mock_agent = MagicMock()
8586
mock_agent.id = "mock-agent-id"
86-
8787
mock_client = MagicMock()
88-
mock_client.agents.delete_agent = MagicMock()
8988

90-
SearchAgentFactory._agent = {
91-
"agent": mock_agent,
92-
"client": mock_client
93-
}
89+
SearchAgentFactory._agent = {"agent": mock_agent, "client": mock_client}
9490

9591
await SearchAgentFactory.delete_agent()
9692

@@ -99,6 +95,8 @@ async def test_delete_agent_removes_agent():
9995

10096

10197
@pytest.mark.asyncio
102-
def test_delete_agent_does_nothing_if_none():
98+
async def test_delete_agent_does_nothing_if_none():
10399
SearchAgentFactory._agent = None
104-
SearchAgentFactory.delete_agent()
100+
await SearchAgentFactory.delete_agent()
101+
# No error should be raised, and nothing is called
102+
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock, AsyncMock
3+
4+
from agents.sql_agent_factory import SQLAgentFactory
5+
6+
7+
@pytest.fixture(autouse=True)
8+
def reset_sql_agent_factory():
9+
SQLAgentFactory._agent = None
10+
yield
11+
SQLAgentFactory._agent = None
12+
13+
14+
@pytest.mark.asyncio
15+
@patch("agents.sql_agent_factory.DefaultAzureCredential", autospec=True)
16+
@patch("agents.sql_agent_factory.AIProjectClient", autospec=True)
17+
async def test_create_agent_creates_new_instance(
18+
mock_ai_client_cls,
19+
mock_credential_cls
20+
):
21+
# Mock config
22+
mock_config = MagicMock()
23+
mock_config.ai_project_endpoint = "https://test-endpoint"
24+
mock_config.ai_project_api_version = "2025-05-01"
25+
mock_config.azure_openai_deployment_model = "test-model"
26+
mock_config.solution_name = "test-solution"
27+
28+
# Mock project client
29+
mock_project_client = MagicMock()
30+
mock_ai_client_cls.return_value = mock_project_client
31+
32+
# Mock agent
33+
mock_agent = MagicMock()
34+
mock_project_client.agents.create_agent.return_value = mock_agent
35+
36+
result = await SQLAgentFactory.create_agent(mock_config)
37+
38+
assert result["agent"] == mock_agent
39+
assert result["client"] == mock_project_client
40+
41+
mock_ai_client_cls.assert_called_once_with(
42+
endpoint="https://test-endpoint",
43+
credential=mock_credential_cls.return_value,
44+
api_version="2025-05-01"
45+
)
46+
mock_project_client.agents.create_agent.assert_called_once()
47+
args, kwargs = mock_project_client.agents.create_agent.call_args
48+
assert kwargs["model"] == "test-model"
49+
assert kwargs["name"] == "KM-ChatWithSQLDatabaseAgent-test-solution"
50+
assert "Generate a valid T-SQL query" in kwargs["instructions"]
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_get_agent_returns_existing_instance():
55+
SQLAgentFactory._agent = {"agent": MagicMock(), "client": MagicMock()}
56+
result = await SQLAgentFactory.get_agent()
57+
assert result == SQLAgentFactory._agent
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_delete_agent_removes_agent():
62+
mock_agent = MagicMock()
63+
mock_agent.id = "agent-id"
64+
65+
mock_client = MagicMock()
66+
mock_client.agents.delete_agent = MagicMock()
67+
68+
SQLAgentFactory._agent = {
69+
"agent": mock_agent,
70+
"client": mock_client
71+
}
72+
73+
await SQLAgentFactory.delete_agent()
74+
75+
mock_client.agents.delete_agent.assert_called_once_with("agent-id")
76+
assert SQLAgentFactory._agent is None
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_delete_agent_does_nothing_if_none():
81+
SQLAgentFactory._agent = None
82+
await SQLAgentFactory.delete_agent()
83+
# Nothing should raise, nothing should be called

0 commit comments

Comments
 (0)