Skip to content

Commit 1e0531d

Browse files
Testcases
1 parent af1681c commit 1e0531d

File tree

1 file changed

+68
-118
lines changed

1 file changed

+68
-118
lines changed

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 68 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from azure.cosmos.partition_key import PartitionKey
55
from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext
66

7-
# Set environment variables globally before importing modules
7+
# Mock environment variables
88
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
99
os.environ["COSMOSDB_KEY"] = "mock-key"
1010
os.environ["COSMOSDB_DATABASE"] = "mock-database"
@@ -62,11 +62,14 @@ async def test_initialize(mock_config, mock_cosmos_client):
6262
context = CosmosBufferedChatCompletionContext(
6363
session_id="test_session", user_id="test_user"
6464
)
65-
await context.initialize()
66-
mock_client.create_container_if_not_exists.assert_called_once_with(
67-
id="mock-container", partition_key=PartitionKey(path="/session_id")
68-
)
69-
assert context._container == mock_container
65+
try:
66+
await context.initialize()
67+
mock_client.create_container_if_not_exists.assert_called_once_with(
68+
id="mock-container", partition_key=PartitionKey(path="/session_id")
69+
)
70+
assert context._container == mock_container
71+
finally:
72+
await context.close()
7073

7174

7275
@pytest.mark.asyncio
@@ -79,12 +82,14 @@ async def test_add_item(mock_config, mock_cosmos_client):
7982
context = CosmosBufferedChatCompletionContext(
8083
session_id="test_session", user_id="test_user"
8184
)
82-
await context.initialize()
83-
await context.add_item(mock_item)
84-
85-
mock_container.create_item.assert_called_once_with(
86-
body={"id": "test-item", "data": "test-data"}
87-
)
85+
try:
86+
await context.initialize()
87+
await context.add_item(mock_item)
88+
mock_container.create_item.assert_called_once_with(
89+
body={"id": "test-item", "data": "test-data"}
90+
)
91+
finally:
92+
await context.close()
8893

8994

9095
@pytest.mark.asyncio
@@ -97,12 +102,14 @@ async def test_update_item(mock_config, mock_cosmos_client):
97102
context = CosmosBufferedChatCompletionContext(
98103
session_id="test_session", user_id="test_user"
99104
)
100-
await context.initialize()
101-
await context.update_item(mock_item)
102-
103-
mock_container.upsert_item.assert_called_once_with(
104-
body={"id": "test-item", "data": "updated-data"}
105-
)
105+
try:
106+
await context.initialize()
107+
await context.update_item(mock_item)
108+
mock_container.upsert_item.assert_called_once_with(
109+
body={"id": "test-item", "data": "updated-data"}
110+
)
111+
finally:
112+
await context.close()
106113

107114

108115
@pytest.mark.asyncio
@@ -118,15 +125,17 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
118125
context = CosmosBufferedChatCompletionContext(
119126
session_id="test_session", user_id="test_user"
120127
)
121-
await context.initialize()
122-
result = await context.get_item_by_id(
123-
"test-item", "test-partition", mock_model_class
124-
)
125-
126-
assert result == "validated_item"
127-
mock_container.read_item.assert_called_once_with(
128-
item="test-item", partition_key="test-partition"
129-
)
128+
try:
129+
await context.initialize()
130+
result = await context.get_item_by_id(
131+
"test-item", "test-partition", mock_model_class
132+
)
133+
assert result == "validated_item"
134+
mock_container.read_item.assert_called_once_with(
135+
item="test-item", partition_key="test-partition"
136+
)
137+
finally:
138+
await context.close()
130139

131140

132141
@pytest.mark.asyncio
@@ -137,108 +146,58 @@ async def test_delete_item(mock_config, mock_cosmos_client):
137146
context = CosmosBufferedChatCompletionContext(
138147
session_id="test_session", user_id="test_user"
139148
)
140-
await context.initialize()
141-
await context.delete_item("test-item", "test-partition")
142-
143-
mock_container.delete_item.assert_called_once_with(
144-
item="test-item", partition_key="test-partition"
145-
)
146-
147-
148-
@pytest.mark.asyncio
149-
async def test_add_plan(mock_config, mock_cosmos_client):
150-
"""Test adding a plan to Cosmos DB."""
151-
_, mock_container = mock_cosmos_client
152-
mock_plan = MagicMock()
153-
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
154-
155-
context = CosmosBufferedChatCompletionContext(
156-
session_id="test_session", user_id="test_user"
157-
)
158-
await context.initialize()
159-
await context.add_plan(mock_plan)
160-
161-
mock_container.create_item.assert_called_once_with(
162-
body={"id": "plan1", "data": "plan-data"}
163-
)
164-
165-
166-
@pytest.mark.asyncio
167-
async def test_update_plan(mock_config, mock_cosmos_client):
168-
"""Test updating a plan in Cosmos DB."""
169-
_, mock_container = mock_cosmos_client
170-
mock_plan = MagicMock()
171-
mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"}
172-
173-
context = CosmosBufferedChatCompletionContext(
174-
session_id="test_session", user_id="test_user"
175-
)
176-
await context.initialize()
177-
await context.update_plan(mock_plan)
178-
179-
mock_container.upsert_item.assert_called_once_with(
180-
body={"id": "plan1", "data": "updated-plan-data"}
181-
)
149+
try:
150+
await context.initialize()
151+
await context.delete_item("test-item", "test-partition")
152+
mock_container.delete_item.assert_called_once_with(
153+
item="test-item", partition_key="test-partition"
154+
)
155+
finally:
156+
await context.close()
182157

183158

184159
@pytest.mark.asyncio
185-
async def test_add_session(mock_config, mock_cosmos_client):
186-
"""Test adding a session to Cosmos DB."""
187-
_, mock_container = mock_cosmos_client
188-
mock_session = MagicMock()
189-
mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"}
190-
160+
async def test_close_without_initialization():
161+
"""Test closing the context without prior initialization."""
191162
context = CosmosBufferedChatCompletionContext(
192163
session_id="test_session", user_id="test_user"
193164
)
194-
await context.initialize()
195-
await context.add_session(mock_session)
196-
197-
mock_container.create_item.assert_called_once_with(
198-
body={"id": "session1", "data": "session-data"}
199-
)
165+
try:
166+
await context.close() # Should handle gracefully even if not initialized
167+
except Exception as e:
168+
pytest.fail(f"Unexpected exception during close: {e}")
200169

201170

202171
@pytest.mark.asyncio
203172
async def test_initialize_event(mock_config, mock_cosmos_client):
204-
"""Test the initialization event is set."""
173+
"""Test if the initialization flag is correctly set."""
205174
_, _ = mock_cosmos_client
206175
context = CosmosBufferedChatCompletionContext(
207176
session_id="test_session", user_id="test_user"
208177
)
209178
assert not context._initialized.is_set()
210-
await context.initialize()
211-
assert context._initialized.is_set()
212-
213-
214-
@pytest.mark.asyncio
215-
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
216-
"""Test querying data with an invalid type."""
217-
_, _ = mock_cosmos_client
218-
context = CosmosBufferedChatCompletionContext(
219-
session_id="test_session", user_id="test_user"
220-
)
221-
222-
result = await context.get_data_by_type("invalid_type")
223-
224-
assert result == [] # Expect empty result for invalid type
179+
try:
180+
await context.initialize()
181+
assert context._initialized.is_set()
182+
finally:
183+
await context.close()
225184

226185

227186
@pytest.mark.asyncio
228187
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
229188
"""Test retrieving a plan with an invalid session ID."""
230189
_, mock_container = mock_cosmos_client
231-
mock_container.query_items.return_value = async_iterable(
232-
[]
233-
) # No results for invalid session
190+
mock_container.query_items.return_value = async_iterable([]) # No results
234191

235192
context = CosmosBufferedChatCompletionContext(
236193
session_id="test_session", user_id="test_user"
237194
)
238-
await context.initialize()
239-
result = await context.get_plan_by_session("invalid_session")
240-
241-
assert result is None
195+
try:
196+
await context.initialize()
197+
result = await context.get_plan_by_session("invalid_session")
198+
assert result is None
199+
finally:
200+
await context.close()
242201

243202

244203
@pytest.mark.asyncio
@@ -250,17 +209,8 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
250209
context = CosmosBufferedChatCompletionContext(
251210
session_id="test_session", user_id="test_user"
252211
)
253-
await context.initialize()
254-
await context.delete_item(
255-
"test-item", "test-partition"
256-
) # Expect no exception to propagate
257-
258-
259-
@pytest.mark.asyncio
260-
async def test_close_without_initialization(mock_config, mock_cosmos_client):
261-
"""Test close method without prior initialization."""
262-
context = CosmosBufferedChatCompletionContext(
263-
session_id="test_session", user_id="test_user"
264-
)
265-
# Expect no exceptions when closing uninitialized context
266-
await context.close()
212+
try:
213+
await context.initialize()
214+
await context.delete_item("test-item", "test-partition")
215+
finally:
216+
await context.close()

0 commit comments

Comments
 (0)