Skip to content

Commit 701695a

Browse files
Testcases
1 parent 945062f commit 701695a

File tree

2 files changed

+148
-2
lines changed

2 files changed

+148
-2
lines changed

src/backend/handlers/runtime_interrupt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class NeedsUserInputHandler(DefaultInterventionHandler):
1212
def __init__(self):
13-
self.question_for_human: Optional[GetHumanInputMessage] = None # type: ignore
13+
self.question_for_human: Optional[GetHumanInputMessage] = None
1414
self.messages: List[Dict[str, Any]] = []
1515

1616
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
from azure.cosmos.partition_key import PartitionKey
1515
from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext
1616

17+
async def async_iterable(mock_items):
18+
"""Helper to create an async iterable."""
19+
for item in mock_items:
20+
yield item
21+
1722
@pytest.fixture(autouse=True)
1823
def mock_env_variables(monkeypatch):
1924
"""Mock all required environment variables."""
@@ -58,4 +63,145 @@ async def test_initialize(mock_config, mock_cosmos_client):
5863
id="mock-container",
5964
partition_key=PartitionKey(path="/session_id")
6065
)
61-
assert context._container == mock_container
66+
assert context._container == mock_container
67+
68+
@pytest.mark.asyncio
69+
async def test_add_item(mock_config, mock_cosmos_client):
70+
"""Test adding an item to Cosmos DB."""
71+
_, mock_container = mock_cosmos_client
72+
mock_item = MagicMock()
73+
mock_item.model_dump.return_value = {"id": "test-item", "data": "test-data"}
74+
75+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
76+
await context.initialize()
77+
await context.add_item(mock_item)
78+
79+
mock_container.create_item.assert_called_once_with(body={"id": "test-item", "data": "test-data"})
80+
81+
@pytest.mark.asyncio
82+
async def test_update_item(mock_config, mock_cosmos_client):
83+
"""Test updating an item in Cosmos DB."""
84+
_, mock_container = mock_cosmos_client
85+
mock_item = MagicMock()
86+
mock_item.model_dump.return_value = {"id": "test-item", "data": "updated-data"}
87+
88+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
89+
await context.initialize()
90+
await context.update_item(mock_item)
91+
92+
mock_container.upsert_item.assert_called_once_with(body={"id": "test-item", "data": "updated-data"})
93+
94+
@pytest.mark.asyncio
95+
async def test_get_item_by_id(mock_config, mock_cosmos_client):
96+
"""Test retrieving an item by ID from Cosmos DB."""
97+
_, mock_container = mock_cosmos_client
98+
mock_item = {"id": "test-item", "data": "retrieved-data"}
99+
mock_container.read_item.return_value = mock_item
100+
101+
mock_model_class = MagicMock()
102+
mock_model_class.model_validate.return_value = "validated_item"
103+
104+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
105+
await context.initialize()
106+
result = await context.get_item_by_id("test-item", "test-partition", mock_model_class)
107+
108+
assert result == "validated_item"
109+
mock_container.read_item.assert_called_once_with(item="test-item", partition_key="test-partition")
110+
111+
@pytest.mark.asyncio
112+
async def test_delete_item(mock_config, mock_cosmos_client):
113+
"""Test deleting an item from Cosmos DB."""
114+
_, mock_container = mock_cosmos_client
115+
116+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
117+
await context.initialize()
118+
await context.delete_item("test-item", "test-partition")
119+
120+
mock_container.delete_item.assert_called_once_with(item="test-item", partition_key="test-partition")
121+
122+
@pytest.mark.asyncio
123+
async def test_add_plan(mock_config, mock_cosmos_client):
124+
"""Test adding a plan to Cosmos DB."""
125+
_, mock_container = mock_cosmos_client
126+
mock_plan = MagicMock()
127+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
128+
129+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
130+
await context.initialize()
131+
await context.add_plan(mock_plan)
132+
133+
mock_container.create_item.assert_called_once_with(body={"id": "plan1", "data": "plan-data"})
134+
135+
@pytest.mark.asyncio
136+
async def test_update_plan(mock_config, mock_cosmos_client):
137+
"""Test updating a plan in Cosmos DB."""
138+
_, mock_container = mock_cosmos_client
139+
mock_plan = MagicMock()
140+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"}
141+
142+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
143+
await context.initialize()
144+
await context.update_plan(mock_plan)
145+
146+
mock_container.upsert_item.assert_called_once_with(body={"id": "plan1", "data": "updated-plan-data"})
147+
148+
@pytest.mark.asyncio
149+
async def test_add_session(mock_config, mock_cosmos_client):
150+
"""Test adding a session to Cosmos DB."""
151+
_, mock_container = mock_cosmos_client
152+
mock_session = MagicMock()
153+
mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"}
154+
155+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
156+
await context.initialize()
157+
await context.add_session(mock_session)
158+
159+
mock_container.create_item.assert_called_once_with(body={"id": "session1", "data": "session-data"})
160+
161+
@pytest.mark.asyncio
162+
async def test_initialize_event(mock_config, mock_cosmos_client):
163+
"""Test the initialization event is set."""
164+
_, _ = mock_cosmos_client
165+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
166+
assert not context._initialized.is_set()
167+
await context.initialize()
168+
assert context._initialized.is_set()
169+
170+
@pytest.mark.asyncio
171+
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
172+
"""Test querying data with an invalid type."""
173+
_, _ = mock_cosmos_client
174+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
175+
176+
result = await context.get_data_by_type("invalid_type")
177+
178+
assert result == [] # Expect empty result for invalid type
179+
180+
@pytest.mark.asyncio
181+
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
182+
"""Test retrieving a plan with an invalid session ID."""
183+
_, mock_container = mock_cosmos_client
184+
mock_container.query_items.return_value = async_iterable([]) # No results for invalid session
185+
186+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
187+
await context.initialize()
188+
result = await context.get_plan_by_session("invalid_session")
189+
190+
assert result is None
191+
192+
@pytest.mark.asyncio
193+
async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
194+
"""Test error handling when deleting an item."""
195+
_, mock_container = mock_cosmos_client
196+
mock_container.delete_item.side_effect = Exception("Delete error")
197+
198+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
199+
await context.initialize()
200+
await context.delete_item("test-item", "test-partition") # Expect no exception to propagate
201+
202+
@pytest.mark.asyncio
203+
async def test_close_without_initialization(mock_config, mock_cosmos_client):
204+
"""Test close method without prior initialization."""
205+
context = CosmosBufferedChatCompletionContext(session_id="test_session", user_id="test_user")
206+
# Expect no exceptions when closing uninitialized context
207+
await context.close()

0 commit comments

Comments
 (0)