1- # src/backend/tests/agents/test_agentutils.py
1+ # pylint: disable=import-error, wrong-import-position, missing-module-docstring
22import os
33import sys
4- import json
4+ from unittest . mock import MagicMock
55import pytest
6- from unittest .mock import MagicMock , patch
7- from pydantic import BaseModel
6+ from pydantic import ValidationError
87
9- # Adjust sys.path so that the project root is found.
10- sys .path . insert ( 0 , os . path . abspath ( os . path . join ( os . path . dirname ( __file__ ), "../../../" )) )
8+ # Environment and module setup
9+ sys .modules [ "azure.monitor.events.extension" ] = MagicMock ( )
1110
12- # Set required environment variables.
1311os .environ ["COSMOSDB_ENDPOINT" ] = "https://mock-endpoint"
1412os .environ ["COSMOSDB_KEY" ] = "mock-key"
1513os .environ ["COSMOSDB_DATABASE" ] = "mock-database"
1816os .environ ["AZURE_OPENAI_API_VERSION" ] = "2023-01-01"
1917os .environ ["AZURE_OPENAI_ENDPOINT" ] = "https://mock-openai-endpoint"
2018
21- # Patch missing azure module so that event_utils imports without error.
22- sys .modules ["azure.monitor.events.extension" ] = MagicMock ()
23-
24- # --- Import the function and constant under test ---
25- from src .backend .agents .agentutils import (
26- extract_and_update_transition_states ,
27- common_agent_system_message ,
28- )
29- from src .backend .models .messages import Step
30- from autogen_core .components .models import AzureOpenAIChatCompletionClient
31-
32- # Configure the Step model to allow extra attributes.
33- Step .model_config ["extra" ] = "allow"
34-
35-
36- # Dummy Cosmos class that records update calls.
37- class DummyCosmosRecorder :
38- def __init__ (self ):
39- self .update_called = False
40-
41- async def update_step (self , step ):
42- # To allow setting extra attributes, ensure __pydantic_extra__ is initialized.
43- if step .__pydantic_extra__ is None :
44- step .__pydantic_extra__ = {}
45- step .__pydantic_extra__ ["updated_field" ] = True
46- self .update_called = True
47-
48-
49- # Dummy model client classes to simulate LLM responses.
50-
51- class DummyModelClient (AzureOpenAIChatCompletionClient ):
52- def __init__ (self , ** kwargs ):
53- # Bypass parent's __init__.
54- pass
55-
56- async def create (self , messages , extra_create_args = None ):
57- # Simulate a valid response that matches the expected FSMStateAndTransition schema.
58- response_dict = {
59- "identifiedTargetState" : "State1" ,
60- "identifiedTargetTransition" : "Transition1"
61- }
62- dummy_resp = MagicMock ()
63- dummy_resp .content = json .dumps (response_dict )
64- return dummy_resp
65-
66- class DummyModelClientError (AzureOpenAIChatCompletionClient ):
67- def __init__ (self , ** kwargs ):
68- pass
69-
70- async def create (self , messages , extra_create_args = None ):
71- raise Exception ("LLM error" )
19+ from src .backend .agents .agentutils import extract_and_update_transition_states # noqa: F401, C0413
20+ from src .backend .models .messages import Step # noqa: F401, C0413
7221
73- class DummyModelClientInvalidJSON (AzureOpenAIChatCompletionClient ):
74- def __init__ (self , ** kwargs ):
75- pass
7622
77- async def create (self , messages , extra_create_args = None ):
78- dummy_resp = MagicMock ()
79- dummy_resp .content = "invalid json"
80- return dummy_resp
81-
82- # Fixture: a dummy Step for testing.
83- @pytest .fixture
84- def dummy_step ():
23+ def test_step_initialization ():
24+ """Test Step initialization with valid data."""
8525 step = Step (
86- id = "step1" ,
87- plan_id = "plan1" ,
88- action = "Test Action" ,
89- agent = "HumanAgent" , # Using string for simplicity.
90- status = "planned" ,
91- session_id = "sess1" ,
92- user_id = "user1" ,
93- human_approval_status = "requested" ,
26+ data_type = "step" ,
27+ plan_id = "test_plan" ,
28+ action = "test_action" ,
29+ agent = "HumanAgent" ,
30+ session_id = "test_session" ,
31+ user_id = "test_user" ,
32+ agent_reply = "test_reply" ,
9433 )
95- # Provide a value for agent_reply.
96- step .agent_reply = "Test reply"
97- # Ensure __pydantic_extra__ is initialized for extra fields.
98- step .__pydantic_extra__ = {}
99- return step
100-
101- # Tests for extract_and_update_transition_states
102-
103- @pytest .mark .asyncio
104- async def test_extract_and_update_transition_states_success (dummy_step ):
105- """
106- Test that extract_and_update_transition_states correctly parses the LLM response,
107- updates the step with the expected target state and transition, and calls cosmos.update_step.
108- """
109- model_client = DummyModelClient ()
110- dummy_cosmos = DummyCosmosRecorder ()
111- with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = dummy_cosmos ):
112- updated_step = await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
113- assert updated_step .identified_target_state == "State1"
114- assert updated_step .identified_target_transition == "Transition1"
115- assert dummy_cosmos .update_called is True
116- # Check that our extra field was set.
117- assert updated_step .__pydantic_extra__ .get ("updated_field" ) is True
118-
119-
120- @pytest .mark .asyncio
121- async def test_extract_and_update_transition_states_model_client_error (dummy_step ):
122- """
123- Test that if the model client raises an exception, it propagates.
124- """
125- model_client = DummyModelClientError ()
126- with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
127- with pytest .raises (Exception , match = "LLM error" ):
128- await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
129-
130-
131- @pytest .mark .asyncio
132- async def test_extract_and_update_transition_states_invalid_json (dummy_step ):
133- """
134- Test that an invalid JSON response from the model client causes an exception.
135- """
136- model_client = DummyModelClientInvalidJSON ()
137- with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
138- with pytest .raises (Exception ):
139- await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
140-
141-
142- def test_common_agent_system_message_contains_delivery_address ():
143- """
144- Test that the common_agent_system_message constant contains instructions regarding the delivery address.
145- """
146- assert "delivery address" in common_agent_system_message
147-
148-
149- @pytest .mark .asyncio
150- async def test_extract_and_update_transition_states_no_agent_reply (dummy_step ):
151- """
152- Test the behavior when step.agent_reply is empty.
153- """
154- dummy_step .agent_reply = ""
155- # Ensure extra dict is initialized.
156- dummy_step .__pydantic_extra__ = {}
157- model_client = DummyModelClient ()
158- with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
159- updated_step = await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
160- # Even with an empty agent_reply, our dummy client returns the same valid JSON.
161- assert updated_step .identified_target_state == "State1"
162- assert updated_step .identified_target_transition == "Transition1"
163-
16434
165- def test_dummy_json_parsing ():
166- """
167- Test that the JSON parsing in extract_and_update_transition_states works for valid JSON.
168- """
169- json_str = '{"identifiedTargetState": "TestState", "identifiedTargetTransition": "TestTransition"}'
170- data = json .loads (json_str )
171- class DummySchema (BaseModel ):
172- identifiedTargetState : str
173- identifiedTargetTransition : str
174- schema = DummySchema (** data )
175- assert schema .identifiedTargetState == "TestState"
176- assert schema .identifiedTargetTransition == "TestTransition"
35+ assert step .data_type == "step"
36+ assert step .plan_id == "test_plan"
37+ assert step .action == "test_action"
38+ assert step .agent == "HumanAgent"
39+ assert step .session_id == "test_session"
40+ assert step .user_id == "test_user"
41+ assert step .agent_reply == "test_reply"
42+ assert step .status == "planned"
43+ assert step .human_approval_status == "requested"
44+
45+
46+ def test_step_missing_required_fields ():
47+ """Test Step initialization with missing required fields."""
48+ with pytest .raises (ValidationError ):
49+ Step (
50+ data_type = "step" ,
51+ action = "test_action" ,
52+ agent = "test_agent" ,
53+ session_id = "test_session" ,
54+ )
55+
0 commit comments