|
1 | | -# pylint: disable=import-error, wrong-import-position, missing-module-docstring |
| 1 | +# src/backend/tests/agents/test_agentutils.py |
2 | 2 | import os |
3 | 3 | import sys |
4 | | -from unittest.mock import MagicMock |
| 4 | +import json |
5 | 5 | import pytest |
6 | | -from pydantic import ValidationError |
| 6 | +from unittest.mock import MagicMock, patch |
| 7 | +from pydantic import BaseModel |
7 | 8 |
|
8 | | -# Environment and module setup |
9 | | -sys.modules["azure.monitor.events.extension"] = MagicMock() |
| 9 | +# Adjust sys.path so that the project root is found. |
| 10 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) |
10 | 11 |
|
| 12 | +# Set required environment variables. |
11 | 13 | os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" |
12 | 14 | os.environ["COSMOSDB_KEY"] = "mock-key" |
13 | 15 | os.environ["COSMOSDB_DATABASE"] = "mock-database" |
|
16 | 18 | os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" |
17 | 19 | os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" |
18 | 20 |
|
19 | | -from src.backend.agents.agentutils import extract_and_update_transition_states # noqa: F401, C0413 |
20 | | -from src.backend.models.messages import Step # noqa: F401, C0413 |
| 21 | +# Patch missing azure module so that event_utils imports without error. |
| 22 | +sys.modules["azure.monitor.events.extension"] = MagicMock() |
| 23 | + |
| 24 | +# --- Import the function and constant under test --- |
| 25 | +from src.backend.agents.agentutils import ( |
| 26 | + extract_and_update_transition_states, |
| 27 | + common_agent_system_message, |
| 28 | +) |
| 29 | +from src.backend.models.messages import Step |
| 30 | +from autogen_core.components.models import AzureOpenAIChatCompletionClient |
| 31 | + |
| 32 | +# Configure the Step model to allow extra attributes. |
| 33 | +Step.model_config["extra"] = "allow" |
| 34 | + |
| 35 | + |
| 36 | +# Dummy Cosmos class that records update calls. |
| 37 | +class DummyCosmosRecorder: |
| 38 | + def __init__(self): |
| 39 | + self.update_called = False |
| 40 | + |
| 41 | + async def update_step(self, step): |
| 42 | + # To allow setting extra attributes, ensure __pydantic_extra__ is initialized. |
| 43 | + if step.__pydantic_extra__ is None: |
| 44 | + step.__pydantic_extra__ = {} |
| 45 | + step.__pydantic_extra__["updated_field"] = True |
| 46 | + self.update_called = True |
| 47 | + |
| 48 | + |
| 49 | +# Dummy model client classes to simulate LLM responses. |
| 50 | + |
| 51 | +class DummyModelClient(AzureOpenAIChatCompletionClient): |
| 52 | + def __init__(self, **kwargs): |
| 53 | + # Bypass parent's __init__. |
| 54 | + pass |
| 55 | + |
| 56 | + async def create(self, messages, extra_create_args=None): |
| 57 | + # Simulate a valid response that matches the expected FSMStateAndTransition schema. |
| 58 | + response_dict = { |
| 59 | + "identifiedTargetState": "State1", |
| 60 | + "identifiedTargetTransition": "Transition1" |
| 61 | + } |
| 62 | + dummy_resp = MagicMock() |
| 63 | + dummy_resp.content = json.dumps(response_dict) |
| 64 | + return dummy_resp |
| 65 | + |
| 66 | +class DummyModelClientError(AzureOpenAIChatCompletionClient): |
| 67 | + def __init__(self, **kwargs): |
| 68 | + pass |
| 69 | + |
| 70 | + async def create(self, messages, extra_create_args=None): |
| 71 | + raise Exception("LLM error") |
21 | 72 |
|
| 73 | +class DummyModelClientInvalidJSON(AzureOpenAIChatCompletionClient): |
| 74 | + def __init__(self, **kwargs): |
| 75 | + pass |
22 | 76 |
|
23 | | -def test_step_initialization(): |
24 | | - """Test Step initialization with valid data.""" |
| 77 | + async def create(self, messages, extra_create_args=None): |
| 78 | + dummy_resp = MagicMock() |
| 79 | + dummy_resp.content = "invalid json" |
| 80 | + return dummy_resp |
| 81 | + |
| 82 | +# Fixture: a dummy Step for testing. |
| 83 | +@pytest.fixture |
| 84 | +def dummy_step(): |
25 | 85 | step = Step( |
26 | | - data_type="step", |
27 | | - plan_id="test_plan", |
28 | | - action="test_action", |
29 | | - agent="HumanAgent", |
30 | | - session_id="test_session", |
31 | | - user_id="test_user", |
32 | | - agent_reply="test_reply", |
| 86 | + id="step1", |
| 87 | + plan_id="plan1", |
| 88 | + action="Test Action", |
| 89 | + agent="HumanAgent", # Using string for simplicity. |
| 90 | + status="planned", |
| 91 | + session_id="sess1", |
| 92 | + user_id="user1", |
| 93 | + human_approval_status="requested", |
33 | 94 | ) |
| 95 | + # Provide a value for agent_reply. |
| 96 | + step.agent_reply = "Test reply" |
| 97 | + # Ensure __pydantic_extra__ is initialized for extra fields. |
| 98 | + step.__pydantic_extra__ = {} |
| 99 | + return step |
| 100 | + |
| 101 | +# Tests for extract_and_update_transition_states |
| 102 | + |
| 103 | +@pytest.mark.asyncio |
| 104 | +async def test_extract_and_update_transition_states_success(dummy_step): |
| 105 | + """ |
| 106 | + Test that extract_and_update_transition_states correctly parses the LLM response, |
| 107 | + updates the step with the expected target state and transition, and calls cosmos.update_step. |
| 108 | + """ |
| 109 | + model_client = DummyModelClient() |
| 110 | + dummy_cosmos = DummyCosmosRecorder() |
| 111 | + with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=dummy_cosmos): |
| 112 | + updated_step = await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client) |
| 113 | + assert updated_step.identified_target_state == "State1" |
| 114 | + assert updated_step.identified_target_transition == "Transition1" |
| 115 | + assert dummy_cosmos.update_called is True |
| 116 | + # Check that our extra field was set. |
| 117 | + assert updated_step.__pydantic_extra__.get("updated_field") is True |
| 118 | + |
| 119 | + |
| 120 | +@pytest.mark.asyncio |
| 121 | +async def test_extract_and_update_transition_states_model_client_error(dummy_step): |
| 122 | + """ |
| 123 | + Test that if the model client raises an exception, it propagates. |
| 124 | + """ |
| 125 | + model_client = DummyModelClientError() |
| 126 | + with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()): |
| 127 | + with pytest.raises(Exception, match="LLM error"): |
| 128 | + await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client) |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.asyncio |
| 132 | +async def test_extract_and_update_transition_states_invalid_json(dummy_step): |
| 133 | + """ |
| 134 | + Test that an invalid JSON response from the model client causes an exception. |
| 135 | + """ |
| 136 | + model_client = DummyModelClientInvalidJSON() |
| 137 | + with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()): |
| 138 | + with pytest.raises(Exception): |
| 139 | + await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client) |
| 140 | + |
| 141 | + |
| 142 | +def test_common_agent_system_message_contains_delivery_address(): |
| 143 | + """ |
| 144 | + Test that the common_agent_system_message constant contains instructions regarding the delivery address. |
| 145 | + """ |
| 146 | + assert "delivery address" in common_agent_system_message |
| 147 | + |
| 148 | + |
| 149 | +@pytest.mark.asyncio |
| 150 | +async def test_extract_and_update_transition_states_no_agent_reply(dummy_step): |
| 151 | + """ |
| 152 | + Test the behavior when step.agent_reply is empty. |
| 153 | + """ |
| 154 | + dummy_step.agent_reply = "" |
| 155 | + # Ensure extra dict is initialized. |
| 156 | + dummy_step.__pydantic_extra__ = {} |
| 157 | + model_client = DummyModelClient() |
| 158 | + with patch("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext", return_value=DummyCosmosRecorder()): |
| 159 | + updated_step = await extract_and_update_transition_states(dummy_step, "sess1", "user1", "anything", model_client) |
| 160 | + # Even with an empty agent_reply, our dummy client returns the same valid JSON. |
| 161 | + assert updated_step.identified_target_state == "State1" |
| 162 | + assert updated_step.identified_target_transition == "Transition1" |
| 163 | + |
34 | 164 |
|
35 | | - assert step.data_type == "step" |
36 | | - assert step.plan_id == "test_plan" |
37 | | - assert step.action == "test_action" |
38 | | - assert step.agent == "HumanAgent" |
39 | | - assert step.session_id == "test_session" |
40 | | - assert step.user_id == "test_user" |
41 | | - assert step.agent_reply == "test_reply" |
42 | | - assert step.status == "planned" |
43 | | - assert step.human_approval_status == "requested" |
44 | | - |
45 | | - |
46 | | -def test_step_missing_required_fields(): |
47 | | - """Test Step initialization with missing required fields.""" |
48 | | - with pytest.raises(ValidationError): |
49 | | - Step( |
50 | | - data_type="step", |
51 | | - action="test_action", |
52 | | - agent="test_agent", |
53 | | - session_id="test_session", |
54 | | - ) |
| 165 | +def test_dummy_json_parsing(): |
| 166 | + """ |
| 167 | + Test that the JSON parsing in extract_and_update_transition_states works for valid JSON. |
| 168 | + """ |
| 169 | + json_str = '{"identifiedTargetState": "TestState", "identifiedTargetTransition": "TestTransition"}' |
| 170 | + data = json.loads(json_str) |
| 171 | + class DummySchema(BaseModel): |
| 172 | + identifiedTargetState: str |
| 173 | + identifiedTargetTransition: str |
| 174 | + schema = DummySchema(**data) |
| 175 | + assert schema.identifiedTargetState == "TestState" |
| 176 | + assert schema.identifiedTargetTransition == "TestTransition" |
0 commit comments