Skip to content

Commit bceb635

Browse files
author
Harmanpreet Kaur
committed
added test_base_agent file
1 parent f84ad1f commit bceb635

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# pylint: disable=import-error, wrong-import-position, missing-module-docstring
2+
import os
3+
import sys
4+
from unittest.mock import MagicMock, AsyncMock, patch
5+
import pytest
6+
from contextlib import contextmanager
7+
8+
# Mocking necessary modules and environment variables
9+
sys.modules["azure.monitor.events.extension"] = MagicMock()
10+
11+
# Mocking environment variables
12+
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
13+
os.environ["COSMOSDB_KEY"] = "mock-key"
14+
os.environ["COSMOSDB_DATABASE"] = "mock-database"
15+
os.environ["COSMOSDB_CONTAINER"] = "mock-container"
16+
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name"
17+
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
18+
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint"
19+
20+
# Importing the module to test
21+
from src.backend.agents.base_agent import BaseAgent
22+
from src.backend.models.messages import ActionRequest, Step, StepStatus, ActionResponse, AgentMessage
23+
from autogen_core.base import AgentId
24+
from autogen_core.components.models import AssistantMessage, UserMessage
25+
26+
# Context manager for setting up mocks
27+
@contextmanager
28+
def mock_context():
29+
mock_runtime = MagicMock()
30+
with patch("autogen_core.base._agent_instantiation.AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR") as mock_context_var:
31+
mock_context_instance = MagicMock()
32+
mock_context_var.get.return_value = mock_context_instance
33+
mock_context_instance.set.return_value = None
34+
yield mock_runtime
35+
36+
@pytest.fixture
37+
def mock_dependencies():
38+
model_client = MagicMock()
39+
model_context = MagicMock()
40+
tools = [MagicMock(schema="tool_schema")]
41+
tool_agent_id = MagicMock()
42+
return {
43+
"model_client": model_client,
44+
"model_context": model_context,
45+
"tools": tools,
46+
"tool_agent_id": tool_agent_id,
47+
}
48+
49+
@pytest.fixture
50+
def base_agent(mock_dependencies):
51+
with mock_context():
52+
return BaseAgent(
53+
agent_name="test_agent",
54+
model_client=mock_dependencies["model_client"],
55+
session_id="test_session",
56+
user_id="test_user",
57+
model_context=mock_dependencies["model_context"],
58+
tools=mock_dependencies["tools"],
59+
tool_agent_id=mock_dependencies["tool_agent_id"],
60+
system_message="This is a system message.",
61+
)
62+
63+
def test_save_state(base_agent, mock_dependencies):
64+
mock_dependencies["model_context"].save_state = MagicMock(return_value={"state_key": "state_value"})
65+
state = base_agent.save_state()
66+
assert state == {"memory": {"state_key": "state_value"}}
67+
68+
def test_load_state(base_agent, mock_dependencies):
69+
mock_dependencies["model_context"].load_state = MagicMock()
70+
state = {"memory": {"state_key": "state_value"}}
71+
base_agent.load_state(state)
72+
mock_dependencies["model_context"].load_state.assert_called_once_with({"state_key": "state_value"})
73+
74+
@pytest.mark.asyncio
75+
async def test_handle_action_request_error(base_agent, mock_dependencies):
76+
"""Test handle_action_request when tool_agent_caller_loop raises an error."""
77+
# Mocking a Step object
78+
step = Step(
79+
id="step_1",
80+
status=StepStatus.approved,
81+
human_feedback="feedback",
82+
agent_reply="",
83+
plan_id="plan_id",
84+
action="action",
85+
agent="HumanAgent",
86+
session_id="session_id",
87+
user_id="user_id",
88+
)
89+
90+
# Mocking the model context methods
91+
mock_dependencies["model_context"].get_step = AsyncMock(return_value=step)
92+
mock_dependencies["model_context"].add_item = AsyncMock()
93+
94+
# Mock tool_agent_caller_loop to raise an exception
95+
with patch("src.backend.agents.base_agent.tool_agent_caller_loop", AsyncMock(side_effect=Exception("Mock error"))):
96+
# Define the ActionRequest message
97+
message = ActionRequest(
98+
step_id="step_1",
99+
session_id="test_session",
100+
action="test_action",
101+
plan_id="plan_id",
102+
agent="HumanAgent",
103+
)
104+
ctx = MagicMock()
105+
106+
# Call handle_action_request and capture exception
107+
with pytest.raises(ValueError) as excinfo:
108+
await base_agent.handle_action_request(message, ctx)
109+
110+
# Assert that the exception matches the expected ValueError
111+
assert "Return type <class 'NoneType'> not in return types" in str(excinfo.value), (
112+
"Expected ValueError due to NoneType return, but got a different exception."
113+
)
114+
115+
@pytest.mark.asyncio
116+
async def test_handle_action_request_success(base_agent, mock_dependencies):
117+
"""Test handle_action_request with a successful tool_agent_caller_loop."""
118+
# Update Step with a valid agent enum value
119+
step = Step(
120+
id="step_1",
121+
status=StepStatus.approved,
122+
human_feedback="feedback",
123+
agent_reply="",
124+
plan_id="plan_id",
125+
action="action",
126+
agent="HumanAgent",
127+
session_id="session_id",
128+
user_id="user_id"
129+
)
130+
mock_dependencies["model_context"].get_step = AsyncMock(return_value=step)
131+
mock_dependencies["model_context"].update_step = AsyncMock()
132+
mock_dependencies["model_context"].add_item = AsyncMock()
133+
134+
# Mock the tool_agent_caller_loop to return a result
135+
with patch("src.backend.agents.base_agent.tool_agent_caller_loop", new=AsyncMock(return_value=[MagicMock(content="result")])):
136+
# Mock the publish_message method to be awaitable
137+
base_agent._runtime.publish_message = AsyncMock()
138+
139+
message = ActionRequest(
140+
step_id="step_1",
141+
session_id="test_session",
142+
action="test_action",
143+
plan_id="plan_id",
144+
agent="HumanAgent"
145+
)
146+
ctx = MagicMock()
147+
148+
# Call the method being tested
149+
response = await base_agent.handle_action_request(message, ctx)
150+
151+
# Assertions to ensure the response is correct
152+
assert response.status == StepStatus.completed
153+
assert response.result == "result"
154+
assert response.plan_id == "plan_id" # Validate plan_id
155+
assert response.session_id == "test_session" # Validate session_id
156+
157+
# Ensure publish_message was called
158+
base_agent._runtime.publish_message.assert_awaited_once_with(
159+
response,
160+
AgentId(type="group_chat_manager", key="test_session"),
161+
sender=base_agent.id,
162+
cancellation_token=None
163+
)
164+
165+
# Ensure the step was updated
166+
mock_dependencies["model_context"].update_step.assert_called_once_with(step)

0 commit comments

Comments
 (0)