Skip to content

Commit 6694d2c

Browse files
Testcases
1 parent 133e876 commit 6694d2c

File tree

1 file changed

+124
-4
lines changed

1 file changed

+124
-4
lines changed

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def mock_env_variables(monkeypatch):
3737
monkeypatch.setenv(key, value)
3838

3939

40+
@pytest.fixture
41+
def mock_azure_credentials():
42+
"""Mock Azure DefaultAzureCredential for all tests."""
43+
with patch("azure.identity.aio.DefaultAzureCredential") as mock_cred:
44+
mock_cred.return_value.get_token = AsyncMock(return_value={"token": "mock-token"})
45+
yield
46+
47+
4048
@pytest.fixture
4149
def mock_cosmos_client():
4250
"""Fixture for mocking Cosmos DB client and container."""
@@ -71,13 +79,12 @@ async def test_initialize(mock_config, mock_cosmos_client):
7179

7280

7381
@pytest.mark.asyncio
74-
async def test_close_without_initialization():
82+
async def test_close_without_initialization(mock_config):
7583
"""Test closing the context without prior initialization."""
7684
async with CosmosBufferedChatCompletionContext(
7785
session_id="test_session", user_id="test_user"
78-
) as context:
79-
# Ensure close is safe even if not explicitly initialized
80-
pass
86+
):
87+
pass # Expect no errors when exiting context
8188

8289

8390
@pytest.mark.asyncio
@@ -95,3 +102,116 @@ async def test_add_item(mock_config, mock_cosmos_client):
95102
mock_container.create_item.assert_called_once_with(
96103
body={"id": "test-item", "data": "test-data"}
97104
)
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_update_item(mock_config, mock_cosmos_client):
109+
"""Test updating an item in Cosmos DB."""
110+
_, mock_container = mock_cosmos_client
111+
mock_item = MagicMock()
112+
mock_item.model_dump.return_value = {"id": "test-item", "data": "updated-data"}
113+
114+
async with CosmosBufferedChatCompletionContext(
115+
session_id="test_session", user_id="test_user"
116+
) as context:
117+
await context.initialize()
118+
await context.update_item(mock_item)
119+
mock_container.upsert_item.assert_called_once_with(
120+
body={"id": "test-item", "data": "updated-data"}
121+
)
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_get_item_by_id(mock_config, mock_cosmos_client):
126+
"""Test retrieving an item by ID from Cosmos DB."""
127+
_, mock_container = mock_cosmos_client
128+
mock_item = {"id": "test-item", "data": "retrieved-data"}
129+
mock_container.read_item.return_value = mock_item
130+
131+
mock_model_class = MagicMock()
132+
mock_model_class.model_validate.return_value = "validated_item"
133+
134+
async with CosmosBufferedChatCompletionContext(
135+
session_id="test_session", user_id="test_user"
136+
) as context:
137+
await context.initialize()
138+
result = await context.get_item_by_id(
139+
"test-item", "test-partition", mock_model_class
140+
)
141+
142+
assert result == "validated_item"
143+
mock_container.read_item.assert_called_once_with(
144+
item="test-item", partition_key="test-partition"
145+
)
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_delete_item(mock_config, mock_cosmos_client):
150+
"""Test deleting an item from Cosmos DB."""
151+
_, mock_container = mock_cosmos_client
152+
153+
async with CosmosBufferedChatCompletionContext(
154+
session_id="test_session", user_id="test_user"
155+
) as context:
156+
await context.initialize()
157+
await context.delete_item("test-item", "test-partition")
158+
159+
mock_container.delete_item.assert_called_once_with(
160+
item="test-item", partition_key="test-partition"
161+
)
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_add_plan(mock_config, mock_cosmos_client):
166+
"""Test adding a plan to Cosmos DB."""
167+
_, mock_container = mock_cosmos_client
168+
mock_plan = MagicMock()
169+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
170+
171+
async with CosmosBufferedChatCompletionContext(
172+
session_id="test_session", user_id="test_user"
173+
) as context:
174+
await context.initialize()
175+
await context.add_plan(mock_plan)
176+
177+
mock_container.create_item.assert_called_once_with(
178+
body={"id": "plan1", "data": "plan-data"}
179+
)
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_update_plan(mock_config, mock_cosmos_client):
184+
"""Test updating a plan in Cosmos DB."""
185+
_, mock_container = mock_cosmos_client
186+
mock_plan = MagicMock()
187+
mock_plan.model_dump.return_value = {
188+
"id": "plan1",
189+
"data": "updated-plan-data",
190+
}
191+
192+
async with CosmosBufferedChatCompletionContext(
193+
session_id="test_session", user_id="test_user"
194+
) as context:
195+
await context.initialize()
196+
await context.update_plan(mock_plan)
197+
198+
mock_container.upsert_item.assert_called_once_with(
199+
body={"id": "plan1", "data": "updated-plan-data"}
200+
)
201+
202+
203+
@pytest.mark.asyncio
204+
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
205+
"""Test retrieving a plan with an invalid session ID."""
206+
_, mock_container = mock_cosmos_client
207+
mock_container.query_items.return_value = async_iterable(
208+
[]
209+
) # No results for invalid session
210+
211+
async with CosmosBufferedChatCompletionContext(
212+
session_id="test_session", user_id="test_user"
213+
) as context:
214+
await context.initialize()
215+
result = await context.get_plan_by_session("invalid_session")
216+
217+
assert result is None

0 commit comments

Comments
 (0)