Skip to content

Commit c88344f

Browse files
TestCase agentutils
1 parent d1ccf12 commit c88344f

File tree

1 file changed

+158
-36
lines changed

1 file changed

+158
-36
lines changed
Lines changed: 158 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
# pylint: disable=import-error, wrong-import-position, missing-module-docstring
21
import os
32
import sys
4-
from unittest.mock import MagicMock
3+
import json
54
import pytest
6-
from pydantic import ValidationError
5+
from unittest.mock import MagicMock, patch
6+
from pydantic import BaseModel
77

8-
# Environment and module setup
9-
sys.modules["azure.monitor.events.extension"] = MagicMock()
8+
# Adjust sys.path so that the project root is found.
9+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
1010

11+
# Set required environment variables.
1112
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
1213
os.environ["COSMOSDB_KEY"] = "mock-key"
1314
os.environ["COSMOSDB_DATABASE"] = "mock-database"
@@ -16,39 +17,160 @@
1617
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
1718
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint"
1819

19-
from src.backend.agents.agentutils import extract_and_update_transition_states # noqa: F401, C0413
20-
from src.backend.models.messages import Step # noqa: F401, C0413
20+
# Patch missing azure module so that event_utils imports without error.
21+
sys.modules["azure.monitor.events.extension"] = MagicMock()
22+
23+
# --- Import the function and constant under test ---
24+
from src.backend.agents.agentutils import (
25+
extract_and_update_transition_states,
26+
common_agent_system_message,
27+
)
28+
from src.backend.models.messages import Step
29+
from autogen_core.components.models import AzureOpenAIChatCompletionClient
30+
31+
# Configure the Step model to allow extra attributes.
32+
Step.model_config["extra"] = "allow"
33+
34+
35+
# Dummy Cosmos class that records update calls.
36+
class DummyCosmosRecorder:
37+
def __init__(self):
38+
self.update_called = False
39+
40+
async def update_step(self, step):
41+
# To allow setting extra attributes, ensure __pydantic_extra__ is initialized.
42+
if step.__pydantic_extra__ is None:
43+
step.__pydantic_extra__ = {}
44+
step.__pydantic_extra__["updated_field"] = True
45+
self.update_called = True
46+
47+
48+
# Dummy model client classes to simulate LLM responses.
49+
50+
class DummyModelClient(AzureOpenAIChatCompletionClient):
51+
def __init__(self, **kwargs):
52+
# Bypass parent's __init__.
53+
pass
54+
55+
async def create(self, messages, extra_create_args=None):
56+
# Simulate a valid response that matches the expected FSMStateAndTransition schema.
57+
response_dict = {
58+
"identifiedTargetState": "State1",
59+
"identifiedTargetTransition": "Transition1"
60+
}
61+
dummy_resp = MagicMock()
62+
dummy_resp.content = json.dumps(response_dict)
63+
return dummy_resp
64+
65+
class DummyModelClientError(AzureOpenAIChatCompletionClient):
66+
def __init__(self, **kwargs):
67+
pass
68+
69+
async def create(self, messages, extra_create_args=None):
70+
raise Exception("LLM error")
2171

72+
class DummyModelClientInvalidJSON(AzureOpenAIChatCompletionClient):
73+
def __init__(self, **kwargs):
74+
pass
2275

23-
def test_step_initialization():
24-
"""Test Step initialization with valid data."""
76+
async def create(self, messages, extra_create_args=None):
77+
dummy_resp = MagicMock()
78+
dummy_resp.content = "invalid json"
79+
return dummy_resp
80+
81+
# Fixture: a dummy Step for testing.
82+
@pytest.fixture
83+
def dummy_step():
2584
step = Step(
26-
data_type="step",
27-
plan_id="test_plan",
28-
action="test_action",
29-
agent="HumanAgent",
30-
session_id="test_session",
31-
user_id="test_user",
32-
agent_reply="test_reply",
85+
id="step1",
86+
plan_id="plan1",
87+
action="Test Action",
88+
agent="HumanAgent", # Using string for simplicity.
89+
status="planned",
90+
session_id="sess1",
91+
user_id="user1",
92+
human_approval_status="requested",
3393
)
94+
# Provide a value for agent_reply.
95+
step.agent_reply = "Test reply"
96+
# Ensure __pydantic_extra__ is initialized for extra fields.
97+
step.__pydantic_extra__ = {}
98+
return step
99+
100+
# Tests for extract_and_update_transition_states
101+
102+
@pytest.mark.asyncio
103+
async def test_extract_and_update_transition_states_success(dummy_step):
104+
"""
105+
Test that extract_and_update_transition_states correctly parses the LLM response,
106+
updates the step with the expected target state and transition, and calls cosmos.update_step.
107+
"""
108+
model_client = DummyModelClient()
109+
dummy_cosmos = DummyCosmosRecorder()
110+
with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=dummy_cosmos):
111+
updated_step = await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client)
112+
assert updated_step.identified_target_state == "State1"
113+
assert updated_step.identified_target_transition == "Transition1"
114+
assert dummy_cosmos.update_called is True
115+
# Check that our extra field was set.
116+
assert updated_step.__pydantic_extra__.get("updated_field") is True
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_extract_and_update_transition_states_model_client_error(dummy_step):
121+
"""
122+
Test that if the model client raises an exception, it propagates.
123+
"""
124+
model_client = DummyModelClientError()
125+
with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()):
126+
with pytest.raises(Exception, match="LLM error"):
127+
await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client)
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_extract_and_update_transition_states_invalid_json(dummy_step):
132+
"""
133+
Test that an invalid JSON response from the model client causes an exception.
134+
"""
135+
model_client = DummyModelClientInvalidJSON()
136+
with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()):
137+
with pytest.raises(Exception):
138+
await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client)
139+
140+
141+
def test_common_agent_system_message_contains_delivery_address():
142+
"""
143+
Test that the common_agent_system_message constant contains instructions regarding the delivery address.
144+
"""
145+
assert "delivery address" in common_agent_system_message
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_extract_and_update_transition_states_no_agent_reply(dummy_step):
150+
"""
151+
Test the behavior when step.agent_reply is empty.
152+
"""
153+
dummy_step.agent_reply = ""
154+
# Ensure extra dict is initialized.
155+
dummy_step.__pydantic_extra__ = {}
156+
model_client = DummyModelClient()
157+
with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()):
158+
updated_step = await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client)
159+
# Even with an empty agent_reply, our dummy client returns the same valid JSON.
160+
assert updated_step.identified_target_state == "State1"
161+
assert updated_step.identified_target_transition == "Transition1"
162+
34163

35-
assert step.data_type == "step"
36-
assert step.plan_id == "test_plan"
37-
assert step.action == "test_action"
38-
assert step.agent == "HumanAgent"
39-
assert step.session_id == "test_session"
40-
assert step.user_id == "test_user"
41-
assert step.agent_reply == "test_reply"
42-
assert step.status == "planned"
43-
assert step.human_approval_status == "requested"
44-
45-
46-
def test_step_missing_required_fields():
47-
"""Test Step initialization with missing required fields."""
48-
with pytest.raises(ValidationError):
49-
Step(
50-
data_type="step",
51-
action="test_action",
52-
agent="test_agent",
53-
session_id="test_session",
54-
)
164+
def test_dummy_json_parsing():
165+
"""
166+
Test that the JSON parsing in extract_and_update_transition_states works for valid JSON.
167+
"""
168+
json_str = '{"identifiedTargetState": "TestState", "identifiedTargetTransition": "TestTransition"}'
169+
data = json.loads(json_str)
170+
class DummySchema(BaseModel):
171+
identifiedTargetState: str
172+
identifiedTargetTransition: str
173+
schema = DummySchema(**data)
174+
assert schema.identifiedTargetState == "TestState"
175+
assert schema.identifiedTargetTransition == "TestTransition"
176+

0 commit comments

Comments
 (0)