Skip to content

Commit 78738ce

Browse files
Testcases
1 parent e2a77bc commit 78738ce

File tree

1 file changed

+76
-135
lines changed

1 file changed

+76
-135
lines changed

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 76 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
1414
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint"
1515

16-
1716
async def async_iterable(mock_items):
1817
"""Helper to create an async iterable."""
1918
for item in mock_items:
2019
yield item
2120

22-
2321
@pytest.fixture(autouse=True)
2422
def mock_env_variables(monkeypatch):
2523
"""Mock all required environment variables."""
@@ -35,6 +33,12 @@ def mock_env_variables(monkeypatch):
3533
for key, value in env_vars.items():
3634
monkeypatch.setenv(key, value)
3735

36+
@pytest.fixture(autouse=True)
37+
def mock_azure_credentials():
38+
"""Mock Azure DefaultAzureCredential for all tests."""
39+
with patch("azure.identity.aio.DefaultAzureCredential") as mock_cred:
40+
mock_cred.return_value.get_token = AsyncMock(return_value={"token": "mock-token"})
41+
yield
3842

3943
@pytest.fixture
4044
def mock_cosmos_client():
@@ -44,7 +48,6 @@ def mock_cosmos_client():
4448
mock_client.create_container_if_not_exists.return_value = mock_container
4549
return mock_client, mock_container
4650

47-
4851
@pytest.fixture
4952
def mock_config(mock_cosmos_client):
5053
"""Fixture to patch Config with mock Cosmos DB client."""
@@ -54,20 +57,21 @@ def mock_config(mock_cosmos_client):
5457
), patch("src.backend.config.Config.COSMOSDB_CONTAINER", "mock-container"):
5558
yield
5659

57-
5860
@pytest.mark.asyncio
5961
async def test_initialize(mock_config, mock_cosmos_client):
6062
"""Test if the Cosmos DB container is initialized correctly."""
6163
mock_client, mock_container = mock_cosmos_client
6264
context = CosmosBufferedChatCompletionContext(
6365
session_id="test_session", user_id="test_user"
6466
)
65-
await context.initialize()
66-
mock_client.create_container_if_not_exists.assert_called_once_with(
67-
id="mock-container", partition_key=PartitionKey(path="/session_id")
68-
)
69-
assert context._container == mock_container
70-
67+
try:
68+
await context.initialize()
69+
mock_client.create_container_if_not_exists.assert_called_once_with(
70+
id="mock-container", partition_key=PartitionKey(path="/session_id")
71+
)
72+
assert context._container == mock_container
73+
finally:
74+
await context.close()
7175

7276
@pytest.mark.asyncio
7377
async def test_add_item(mock_config, mock_cosmos_client):
@@ -79,13 +83,14 @@ async def test_add_item(mock_config, mock_cosmos_client):
7983
context = CosmosBufferedChatCompletionContext(
8084
session_id="test_session", user_id="test_user"
8185
)
82-
await context.initialize()
83-
await context.add_item(mock_item)
84-
85-
mock_container.create_item.assert_called_once_with(
86-
body={"id": "test-item", "data": "test-data"}
87-
)
88-
86+
try:
87+
await context.initialize()
88+
await context.add_item(mock_item)
89+
mock_container.create_item.assert_called_once_with(
90+
body={"id": "test-item", "data": "test-data"}
91+
)
92+
finally:
93+
await context.close()
8994

9095
@pytest.mark.asyncio
9196
async def test_update_item(mock_config, mock_cosmos_client):
@@ -97,13 +102,14 @@ async def test_update_item(mock_config, mock_cosmos_client):
97102
context = CosmosBufferedChatCompletionContext(
98103
session_id="test_session", user_id="test_user"
99104
)
100-
await context.initialize()
101-
await context.update_item(mock_item)
102-
103-
mock_container.upsert_item.assert_called_once_with(
104-
body={"id": "test-item", "data": "updated-data"}
105-
)
106-
105+
try:
106+
await context.initialize()
107+
await context.update_item(mock_item)
108+
mock_container.upsert_item.assert_called_once_with(
109+
body={"id": "test-item", "data": "updated-data"}
110+
)
111+
finally:
112+
await context.close()
107113

108114
@pytest.mark.asyncio
109115
async def test_get_item_by_id(mock_config, mock_cosmos_client):
@@ -118,128 +124,72 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
118124
context = CosmosBufferedChatCompletionContext(
119125
session_id="test_session", user_id="test_user"
120126
)
121-
await context.initialize()
122-
result = await context.get_item_by_id(
123-
"test-item", "test-partition", mock_model_class
124-
)
125-
126-
assert result == "validated_item"
127-
mock_container.read_item.assert_called_once_with(
128-
item="test-item", partition_key="test-partition"
129-
)
130-
127+
try:
128+
await context.initialize()
129+
result = await context.get_item_by_id(
130+
"test-item", "test-partition", mock_model_class
131+
)
132+
assert result == "validated_item"
133+
mock_container.read_item.assert_called_once_with(
134+
item="test-item", partition_key="test-partition"
135+
)
136+
finally:
137+
await context.close()
131138

132139
@pytest.mark.asyncio
133140
async def test_delete_item(mock_config, mock_cosmos_client):
134141
"""Test deleting an item from Cosmos DB."""
135142
_, mock_container = mock_cosmos_client
136-
137-
context = CosmosBufferedChatCompletionContext(
138-
session_id="test_session", user_id="test_user"
139-
)
140-
await context.initialize()
141-
await context.delete_item("test-item", "test-partition")
142-
143-
mock_container.delete_item.assert_called_once_with(
144-
item="test-item", partition_key="test-partition"
145-
)
146-
147-
148-
@pytest.mark.asyncio
149-
async def test_add_plan(mock_config, mock_cosmos_client):
150-
"""Test adding a plan to Cosmos DB."""
151-
_, mock_container = mock_cosmos_client
152-
mock_plan = MagicMock()
153-
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
154-
155-
context = CosmosBufferedChatCompletionContext(
156-
session_id="test_session", user_id="test_user"
157-
)
158-
await context.initialize()
159-
await context.add_plan(mock_plan)
160-
161-
mock_container.create_item.assert_called_once_with(
162-
body={"id": "plan1", "data": "plan-data"}
163-
)
164-
165-
166-
@pytest.mark.asyncio
167-
async def test_update_plan(mock_config, mock_cosmos_client):
168-
"""Test updating a plan in Cosmos DB."""
169-
_, mock_container = mock_cosmos_client
170-
mock_plan = MagicMock()
171-
mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"}
172-
173143
context = CosmosBufferedChatCompletionContext(
174144
session_id="test_session", user_id="test_user"
175145
)
176-
await context.initialize()
177-
await context.update_plan(mock_plan)
178-
179-
mock_container.upsert_item.assert_called_once_with(
180-
body={"id": "plan1", "data": "updated-plan-data"}
181-
)
182-
183-
184-
@pytest.mark.asyncio
185-
async def test_add_session(mock_config, mock_cosmos_client):
186-
"""Test adding a session to Cosmos DB."""
187-
_, mock_container = mock_cosmos_client
188-
mock_session = MagicMock()
189-
mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"}
190-
191-
context = CosmosBufferedChatCompletionContext(
192-
session_id="test_session", user_id="test_user"
193-
)
194-
await context.initialize()
195-
await context.add_session(mock_session)
196-
197-
mock_container.create_item.assert_called_once_with(
198-
body={"id": "session1", "data": "session-data"}
199-
)
200-
201-
202-
@pytest.mark.asyncio
203-
async def test_initialize_event(mock_config, mock_cosmos_client):
204-
"""Test the initialization event is set."""
205-
_, _ = mock_cosmos_client
206-
context = CosmosBufferedChatCompletionContext(
207-
session_id="test_session", user_id="test_user"
208-
)
209-
assert not context._initialized.is_set()
210-
await context.initialize()
211-
assert context._initialized.is_set()
212-
146+
try:
147+
await context.initialize()
148+
await context.delete_item("test-item", "test-partition")
149+
mock_container.delete_item.assert_called_once_with(
150+
item="test-item", partition_key="test-partition"
151+
)
152+
finally:
153+
await context.close()
213154

214155
@pytest.mark.asyncio
215156
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
216157
"""Test querying data with an invalid type."""
217-
_, _ = mock_cosmos_client
218158
context = CosmosBufferedChatCompletionContext(
219159
session_id="test_session", user_id="test_user"
220160
)
221-
222-
result = await context.get_data_by_type("invalid_type")
223-
224-
assert result == [] # Expect empty result for invalid type
225-
161+
try:
162+
result = await context.get_data_by_type("invalid_type")
163+
assert result == [] # Expect empty result for invalid type
164+
finally:
165+
await context.close()
226166

227167
@pytest.mark.asyncio
228168
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
229169
"""Test retrieving a plan with an invalid session ID."""
230170
_, mock_container = mock_cosmos_client
231-
mock_container.query_items.return_value = async_iterable(
232-
[]
233-
) # No results for invalid session
171+
mock_container.query_items.return_value = async_iterable([]) # No results
234172

235173
context = CosmosBufferedChatCompletionContext(
236174
session_id="test_session", user_id="test_user"
237175
)
238-
await context.initialize()
239-
result = await context.get_plan_by_session("invalid_session")
240-
241-
assert result is None
176+
try:
177+
await context.initialize()
178+
result = await context.get_plan_by_session("invalid_session")
179+
assert result is None
180+
finally:
181+
await context.close()
242182

183+
@pytest.mark.asyncio
184+
async def test_close_without_initialization(mock_config):
185+
"""Test close method without prior initialization."""
186+
context = CosmosBufferedChatCompletionContext(
187+
session_id="test_session", user_id="test_user"
188+
)
189+
try:
190+
await context.close()
191+
except Exception as e:
192+
pytest.fail(f"Unexpected exception during close: {e}")
243193

244194
@pytest.mark.asyncio
245195
async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
@@ -250,17 +200,8 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
250200
context = CosmosBufferedChatCompletionContext(
251201
session_id="test_session", user_id="test_user"
252202
)
253-
await context.initialize()
254-
await context.delete_item(
255-
"test-item", "test-partition"
256-
) # Expect no exception to propagate
257-
258-
259-
@pytest.mark.asyncio
260-
async def test_close_without_initialization(mock_config, mock_cosmos_client):
261-
"""Test close method without prior initialization."""
262-
context = CosmosBufferedChatCompletionContext(
263-
session_id="test_session", user_id="test_user"
264-
)
265-
# Expect no exceptions when closing uninitialized context
266-
await context.close()
203+
try:
204+
await context.initialize()
205+
await context.delete_item("test-item", "test-partition")
206+
finally:
207+
await context.close()

0 commit comments

Comments
 (0)