1- # pylint: disable=import-error, wrong-import-position, missing-module-docstring
21import os
32import sys
4- from unittest . mock import MagicMock
3+ import json
54import pytest
6- from pydantic import ValidationError
5+ from unittest .mock import MagicMock , patch
6+ from pydantic import BaseModel
77
8- # Environment and module setup
9- sys .modules [ "azure.monitor.events.extension" ] = MagicMock ( )
8+ # Adjust sys.path so that the project root is found.
9+ sys .path . insert ( 0 , os . path . abspath ( os . path . join ( os . path . dirname ( __file__ ), "../../../" )) )
1010
11+ # Set required environment variables.
1112os .environ ["COSMOSDB_ENDPOINT" ] = "https://mock-endpoint"
1213os .environ ["COSMOSDB_KEY" ] = "mock-key"
1314os .environ ["COSMOSDB_DATABASE" ] = "mock-database"
1617os .environ ["AZURE_OPENAI_API_VERSION" ] = "2023-01-01"
1718os .environ ["AZURE_OPENAI_ENDPOINT" ] = "https://mock-openai-endpoint"
1819
19- from src .backend .agents .agentutils import extract_and_update_transition_states # noqa: F401, C0413
20- from src .backend .models .messages import Step # noqa: F401, C0413
20+ # Patch missing azure module so that event_utils imports without error.
21+ sys .modules ["azure.monitor.events.extension" ] = MagicMock ()
22+
23+ # --- Import the function and constant under test ---
24+ from src .backend .agents .agentutils import (
25+ extract_and_update_transition_states ,
26+ common_agent_system_message ,
27+ )
28+ from src .backend .models .messages import Step
29+ from autogen_core .components .models import AzureOpenAIChatCompletionClient
30+
31+ # Configure the Step model to allow extra attributes.
32+ Step .model_config ["extra" ] = "allow"
33+
34+
35+ # Dummy Cosmos class that records update calls.
36+ class DummyCosmosRecorder :
37+ def __init__ (self ):
38+ self .update_called = False
39+
40+ async def update_step (self , step ):
41+ # To allow setting extra attributes, ensure __pydantic_extra__ is initialized.
42+ if step .__pydantic_extra__ is None :
43+ step .__pydantic_extra__ = {}
44+ step .__pydantic_extra__ ["updated_field" ] = True
45+ self .update_called = True
46+
47+
48+ # Dummy model client classes to simulate LLM responses.
49+
50+ class DummyModelClient (AzureOpenAIChatCompletionClient ):
51+ def __init__ (self , ** kwargs ):
52+ # Bypass parent's __init__.
53+ pass
54+
55+ async def create (self , messages , extra_create_args = None ):
56+ # Simulate a valid response that matches the expected FSMStateAndTransition schema.
57+ response_dict = {
58+ "identifiedTargetState" : "State1" ,
59+ "identifiedTargetTransition" : "Transition1"
60+ }
61+ dummy_resp = MagicMock ()
62+ dummy_resp .content = json .dumps (response_dict )
63+ return dummy_resp
64+
65+ class DummyModelClientError (AzureOpenAIChatCompletionClient ):
66+ def __init__ (self , ** kwargs ):
67+ pass
68+
69+ async def create (self , messages , extra_create_args = None ):
70+ raise Exception ("LLM error" )
2171
72+ class DummyModelClientInvalidJSON (AzureOpenAIChatCompletionClient ):
73+ def __init__ (self , ** kwargs ):
74+ pass
2275
23- def test_step_initialization ():
24- """Test Step initialization with valid data."""
76+ async def create (self , messages , extra_create_args = None ):
77+ dummy_resp = MagicMock ()
78+ dummy_resp .content = "invalid json"
79+ return dummy_resp
80+
81+ # Fixture: a dummy Step for testing.
82+ @pytest .fixture
83+ def dummy_step ():
2584 step = Step (
26- data_type = "step" ,
27- plan_id = "test_plan" ,
28- action = "test_action" ,
29- agent = "HumanAgent" ,
30- session_id = "test_session" ,
31- user_id = "test_user" ,
32- agent_reply = "test_reply" ,
85+ id = "step1" ,
86+ plan_id = "plan1" ,
87+ action = "Test Action" ,
88+ agent = "HumanAgent" , # Using string for simplicity.
89+ status = "planned" ,
90+ session_id = "sess1" ,
91+ user_id = "user1" ,
92+ human_approval_status = "requested" ,
3393 )
94+ # Provide a value for agent_reply.
95+ step .agent_reply = "Test reply"
96+ # Ensure __pydantic_extra__ is initialized for extra fields.
97+ step .__pydantic_extra__ = {}
98+ return step
99+
100+ # Tests for extract_and_update_transition_states
101+
102+ @pytest .mark .asyncio
103+ async def test_extract_and_update_transition_states_success (dummy_step ):
104+ """
105+ Test that extract_and_update_transition_states correctly parses the LLM response,
106+ updates the step with the expected target state and transition, and calls cosmos.update_step.
107+ """
108+ model_client = DummyModelClient ()
109+ dummy_cosmos = DummyCosmosRecorder ()
110+ with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = dummy_cosmos ):
111+ updated_step = await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
112+ assert updated_step .identified_target_state == "State1"
113+ assert updated_step .identified_target_transition == "Transition1"
114+ assert dummy_cosmos .update_called is True
115+ # Check that our extra field was set.
116+ assert updated_step .__pydantic_extra__ .get ("updated_field" ) is True
117+
118+
119+ @pytest .mark .asyncio
120+ async def test_extract_and_update_transition_states_model_client_error (dummy_step ):
121+ """
122+ Test that if the model client raises an exception, it propagates.
123+ """
124+ model_client = DummyModelClientError ()
125+ with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
126+ with pytest .raises (Exception , match = "LLM error" ):
127+ await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
128+
129+
130+ @pytest .mark .asyncio
131+ async def test_extract_and_update_transition_states_invalid_json (dummy_step ):
132+ """
133+ Test that an invalid JSON response from the model client causes an exception.
134+ """
135+ model_client = DummyModelClientInvalidJSON ()
136+ with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
137+ with pytest .raises (Exception ):
138+ await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
139+
140+
141+ def test_common_agent_system_message_contains_delivery_address ():
142+ """
143+ Test that the common_agent_system_message constant contains instructions regarding the delivery address.
144+ """
145+ assert "delivery address" in common_agent_system_message
146+
147+
148+ @pytest .mark .asyncio
149+ async def test_extract_and_update_transition_states_no_agent_reply (dummy_step ):
150+ """
151+ Test the behavior when step.agent_reply is empty.
152+ """
153+ dummy_step .agent_reply = ""
154+ # Ensure extra dict is initialized.
155+ dummy_step .__pydantic_extra__ = {}
156+ model_client = DummyModelClient ()
157+ with patch ("src.backend.agents.agentutils.CosmosBufferedChatCompletionContext" , return_value = DummyCosmosRecorder ()):
158+ updated_step = await extract_and_update_transition_states (dummy_step , "sess1" , "user1" , "anything" , model_client )
159+ # Even with an empty agent_reply, our dummy client returns the same valid JSON.
160+ assert updated_step .identified_target_state == "State1"
161+ assert updated_step .identified_target_transition == "Transition1"
162+
34163
35- assert step .data_type == "step"
36- assert step .plan_id == "test_plan"
37- assert step .action == "test_action"
38- assert step .agent == "HumanAgent"
39- assert step .session_id == "test_session"
40- assert step .user_id == "test_user"
41- assert step .agent_reply == "test_reply"
42- assert step .status == "planned"
43- assert step .human_approval_status == "requested"
44-
45-
46- def test_step_missing_required_fields ():
47- """Test Step initialization with missing required fields."""
48- with pytest .raises (ValidationError ):
49- Step (
50- data_type = "step" ,
51- action = "test_action" ,
52- agent = "test_agent" ,
53- session_id = "test_session" ,
54- )
164+ def test_dummy_json_parsing ():
165+ """
166+ Test that the JSON parsing in extract_and_update_transition_states works for valid JSON.
167+ """
168+ json_str = '{"identifiedTargetState": "TestState", "identifiedTargetTransition": "TestTransition"}'
169+ data = json .loads (json_str )
170+ class DummySchema (BaseModel ):
171+ identifiedTargetState : str
172+ identifiedTargetTransition : str
173+ schema = DummySchema (** data )
174+ assert schema .identifiedTargetState == "TestState"
175+ assert schema .identifiedTargetTransition == "TestTransition"
176+
0 commit comments