1+ # pylint: disable=import-error, wrong-import-position, missing-module-docstring
2+ import json
3+ import os
4+ import sys
5+ from unittest .mock import AsyncMock , MagicMock , patch
6+ import pytest
7+ from pydantic import ValidationError
8+
9+
10+ # Environment and module setup
11+ sys .modules ["azure.monitor.events.extension" ] = MagicMock ()
12+
13+ os .environ ["COSMOSDB_ENDPOINT" ] = "https://mock-endpoint"
14+ os .environ ["COSMOSDB_KEY" ] = "mock-key"
15+ os .environ ["COSMOSDB_DATABASE" ] = "mock-database"
16+ os .environ ["COSMOSDB_CONTAINER" ] = "mock-container"
17+ os .environ ["AZURE_OPENAI_DEPLOYMENT_NAME" ] = "mock-deployment-name"
18+ os .environ ["AZURE_OPENAI_API_VERSION" ] = "2023-01-01"
19+ os .environ ["AZURE_OPENAI_ENDPOINT" ] = "https://mock-openai-endpoint"
20+
21+ # noqa: F401 is to ignore unused import warnings (if any)
22+ from src .backend .agents .agentutils import extract_and_update_transition_states # noqa: F401, C0413
23+ from src .backend .models .messages import Step # noqa: F401, C0413
24+
25+
26+ @pytest .mark .asyncio
27+ async def test_extract_and_update_transition_states_invalid_response ():
28+ """Test handling of invalid JSON response from model client."""
29+ session_id = "test_session"
30+ user_id = "test_user"
31+ step = Step (
32+ data_type = "step" ,
33+ plan_id = "test_plan" ,
34+ action = "test_action" ,
35+ agent = "HumanAgent" ,
36+ session_id = session_id ,
37+ user_id = user_id ,
38+ agent_reply = "test_reply" ,
39+ )
40+ model_client = AsyncMock ()
41+ cosmos_mock = MagicMock ()
42+
43+ model_client .create .return_value = MagicMock (content = "invalid_json" )
44+
45+ with patch (
46+ "src.backend.context.cosmos_memory.CosmosBufferedChatCompletionContext" ,
47+ cosmos_mock ,
48+ ):
49+ with pytest .raises (json .JSONDecodeError ):
50+ await extract_and_update_transition_states (
51+ step = step ,
52+ session_id = session_id ,
53+ user_id = user_id ,
54+ planner_dynamic_or_workflow = "workflow" ,
55+ model_client = model_client ,
56+ )
57+
58+ cosmos_mock .update_step .assert_not_called ()
59+
60+
61+ @pytest .mark .asyncio
62+ async def test_extract_and_update_transition_states_validation_error ():
63+ """Test handling of a response missing required fields."""
64+ session_id = "test_session"
65+ user_id = "test_user"
66+ step = Step (
67+ data_type = "step" ,
68+ plan_id = "test_plan" ,
69+ action = "test_action" ,
70+ agent = "HumanAgent" ,
71+ session_id = session_id ,
72+ user_id = user_id ,
73+ agent_reply = "test_reply" ,
74+ )
75+ model_client = AsyncMock ()
76+ cosmos_mock = MagicMock ()
77+
78+ invalid_response = {
79+ "identifiedTargetState" : "state1"
80+ } # Missing 'identifiedTargetTransition'
81+ model_client .create .return_value = MagicMock (content = json .dumps (invalid_response ))
82+
83+ with patch (
84+ "src.backend.context.cosmos_memory.CosmosBufferedChatCompletionContext" ,
85+ cosmos_mock ,
86+ ):
87+ with pytest .raises (ValidationError ):
88+ await extract_and_update_transition_states (
89+ step = step ,
90+ session_id = session_id ,
91+ user_id = user_id ,
92+ planner_dynamic_or_workflow = "workflow" ,
93+ model_client = model_client ,
94+ )
95+
96+ cosmos_mock .update_step .assert_not_called ()
97+
98+
99+ def test_step_initialization ():
100+ """Test Step initialization with valid data."""
101+ step = Step (
102+ data_type = "step" ,
103+ plan_id = "test_plan" ,
104+ action = "test_action" ,
105+ agent = "HumanAgent" ,
106+ session_id = "test_session" ,
107+ user_id = "test_user" ,
108+ agent_reply = "test_reply" ,
109+ )
110+
111+ assert step .data_type == "step"
112+ assert step .plan_id == "test_plan"
113+ assert step .action == "test_action"
114+ assert step .agent == "HumanAgent"
115+ assert step .session_id == "test_session"
116+ assert step .user_id == "test_user"
117+ assert step .agent_reply == "test_reply"
118+ assert step .status == "planned"
119+ assert step .human_approval_status == "requested"
120+
121+
122+ def test_step_missing_required_fields ():
123+ """Test Step initialization with missing required fields."""
124+ with pytest .raises (ValidationError ):
125+ Step (
126+ data_type = "step" ,
127+ action = "test_action" ,
128+ agent = "test_agent" ,
129+ session_id = "test_session" ,
130+ )
0 commit comments