Skip to content

Commit 40aef6b

Browse files
author
Harmanpreet Kaur
committed
edited flak
1 parent 8ba7f17 commit 40aef6b

File tree

2 files changed

+131
-1
lines changed

2 files changed

+131
-1
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
max-line-length = 88
33
extend-ignore = E501
44
exclude = .venv, frontend
5-
ignore = E203, W503, G004, G200
5+
ignore = E203, W503, G004, G200, E402
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# pylint: disable=import-error, wrong-import-position, missing-module-docstring
2+
import json
3+
import os
4+
import sys
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
import pytest
7+
from pydantic import ValidationError
8+
9+
10+
# Environment and module setup
11+
sys.modules["azure.monitor.events.extension"] = MagicMock()
12+
13+
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
14+
os.environ["COSMOSDB_KEY"] = "mock-key"
15+
os.environ["COSMOSDB_DATABASE"] = "mock-database"
16+
os.environ["COSMOSDB_CONTAINER"] = "mock-container"
17+
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name"
18+
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
19+
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint"
20+
21+
# noqa: F401 is to ignore unused import warnings (if any)
22+
from src.backend.agents.agentutils import extract_and_update_transition_states # noqa: F401, C0413
23+
from src.backend.models.messages import Step # noqa: F401, C0413
24+
25+
26+
@pytest.mark.asyncio
27+
async def test_extract_and_update_transition_states_invalid_response():
28+
"""Test handling of invalid JSON response from model client."""
29+
session_id = "test_session"
30+
user_id = "test_user"
31+
step = Step(
32+
data_type="step",
33+
plan_id="test_plan",
34+
action="test_action",
35+
agent="HumanAgent",
36+
session_id=session_id,
37+
user_id=user_id,
38+
agent_reply="test_reply",
39+
)
40+
model_client = AsyncMock()
41+
cosmos_mock = MagicMock()
42+
43+
model_client.create.return_value = MagicMock(content="invalid_json")
44+
45+
with patch(
46+
"src.backend.context.cosmos_memory.CosmosBufferedChatCompletionContext",
47+
cosmos_mock,
48+
):
49+
with pytest.raises(json.JSONDecodeError):
50+
await extract_and_update_transition_states(
51+
step=step,
52+
session_id=session_id,
53+
user_id=user_id,
54+
planner_dynamic_or_workflow="workflow",
55+
model_client=model_client,
56+
)
57+
58+
cosmos_mock.update_step.assert_not_called()
59+
60+
61+
@pytest.mark.asyncio
62+
async def test_extract_and_update_transition_states_validation_error():
63+
"""Test handling of a response missing required fields."""
64+
session_id = "test_session"
65+
user_id = "test_user"
66+
step = Step(
67+
data_type="step",
68+
plan_id="test_plan",
69+
action="test_action",
70+
agent="HumanAgent",
71+
session_id=session_id,
72+
user_id=user_id,
73+
agent_reply="test_reply",
74+
)
75+
model_client = AsyncMock()
76+
cosmos_mock = MagicMock()
77+
78+
invalid_response = {
79+
"identifiedTargetState": "state1"
80+
} # Missing 'identifiedTargetTransition'
81+
model_client.create.return_value = MagicMock(content=json.dumps(invalid_response))
82+
83+
with patch(
84+
"src.backend.context.cosmos_memory.CosmosBufferedChatCompletionContext",
85+
cosmos_mock,
86+
):
87+
with pytest.raises(ValidationError):
88+
await extract_and_update_transition_states(
89+
step=step,
90+
session_id=session_id,
91+
user_id=user_id,
92+
planner_dynamic_or_workflow="workflow",
93+
model_client=model_client,
94+
)
95+
96+
cosmos_mock.update_step.assert_not_called()
97+
98+
99+
def test_step_initialization():
100+
"""Test Step initialization with valid data."""
101+
step = Step(
102+
data_type="step",
103+
plan_id="test_plan",
104+
action="test_action",
105+
agent="HumanAgent",
106+
session_id="test_session",
107+
user_id="test_user",
108+
agent_reply="test_reply",
109+
)
110+
111+
assert step.data_type == "step"
112+
assert step.plan_id == "test_plan"
113+
assert step.action == "test_action"
114+
assert step.agent == "HumanAgent"
115+
assert step.session_id == "test_session"
116+
assert step.user_id == "test_user"
117+
assert step.agent_reply == "test_reply"
118+
assert step.status == "planned"
119+
assert step.human_approval_status == "requested"
120+
121+
122+
def test_step_missing_required_fields():
123+
"""Test Step initialization with missing required fields."""
124+
with pytest.raises(ValidationError):
125+
Step(
126+
data_type="step",
127+
action="test_action",
128+
agent="test_agent",
129+
session_id="test_session",
130+
)

0 commit comments

Comments
 (0)