Skip to content

Commit af1681c

Browse files
Testcases
1 parent 78738ce commit af1681c

File tree

1 file changed

+135
-76
lines changed

1 file changed

+135
-76
lines changed

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 135 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
1414
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint"
1515

16+
1617
async def async_iterable(mock_items):
1718
"""Helper to create an async iterable."""
1819
for item in mock_items:
1920
yield item
2021

22+
2123
@pytest.fixture(autouse=True)
2224
def mock_env_variables(monkeypatch):
2325
"""Mock all required environment variables."""
@@ -33,12 +35,6 @@ def mock_env_variables(monkeypatch):
3335
for key, value in env_vars.items():
3436
monkeypatch.setenv(key, value)
3537

36-
@pytest.fixture(autouse=True)
37-
def mock_azure_credentials():
38-
"""Mock Azure DefaultAzureCredential for all tests."""
39-
with patch("azure.identity.aio.DefaultAzureCredential") as mock_cred:
40-
mock_cred.return_value.get_token = AsyncMock(return_value={"token": "mock-token"})
41-
yield
4238

4339
@pytest.fixture
4440
def mock_cosmos_client():
@@ -48,6 +44,7 @@ def mock_cosmos_client():
4844
mock_client.create_container_if_not_exists.return_value = mock_container
4945
return mock_client, mock_container
5046

47+
5148
@pytest.fixture
5249
def mock_config(mock_cosmos_client):
5350
"""Fixture to patch Config with mock Cosmos DB client."""
@@ -57,21 +54,20 @@ def mock_config(mock_cosmos_client):
5754
), patch("src.backend.config.Config.COSMOSDB_CONTAINER", "mock-container"):
5855
yield
5956

57+
6058
@pytest.mark.asyncio
6159
async def test_initialize(mock_config, mock_cosmos_client):
6260
"""Test if the Cosmos DB container is initialized correctly."""
6361
mock_client, mock_container = mock_cosmos_client
6462
context = CosmosBufferedChatCompletionContext(
6563
session_id="test_session", user_id="test_user"
6664
)
67-
try:
68-
await context.initialize()
69-
mock_client.create_container_if_not_exists.assert_called_once_with(
70-
id="mock-container", partition_key=PartitionKey(path="/session_id")
71-
)
72-
assert context._container == mock_container
73-
finally:
74-
await context.close()
65+
await context.initialize()
66+
mock_client.create_container_if_not_exists.assert_called_once_with(
67+
id="mock-container", partition_key=PartitionKey(path="/session_id")
68+
)
69+
assert context._container == mock_container
70+
7571

7672
@pytest.mark.asyncio
7773
async def test_add_item(mock_config, mock_cosmos_client):
@@ -83,14 +79,13 @@ async def test_add_item(mock_config, mock_cosmos_client):
8379
context = CosmosBufferedChatCompletionContext(
8480
session_id="test_session", user_id="test_user"
8581
)
86-
try:
87-
await context.initialize()
88-
await context.add_item(mock_item)
89-
mock_container.create_item.assert_called_once_with(
90-
body={"id": "test-item", "data": "test-data"}
91-
)
92-
finally:
93-
await context.close()
82+
await context.initialize()
83+
await context.add_item(mock_item)
84+
85+
mock_container.create_item.assert_called_once_with(
86+
body={"id": "test-item", "data": "test-data"}
87+
)
88+
9489

9590
@pytest.mark.asyncio
9691
async def test_update_item(mock_config, mock_cosmos_client):
@@ -102,14 +97,13 @@ async def test_update_item(mock_config, mock_cosmos_client):
10297
context = CosmosBufferedChatCompletionContext(
10398
session_id="test_session", user_id="test_user"
10499
)
105-
try:
106-
await context.initialize()
107-
await context.update_item(mock_item)
108-
mock_container.upsert_item.assert_called_once_with(
109-
body={"id": "test-item", "data": "updated-data"}
110-
)
111-
finally:
112-
await context.close()
100+
await context.initialize()
101+
await context.update_item(mock_item)
102+
103+
mock_container.upsert_item.assert_called_once_with(
104+
body={"id": "test-item", "data": "updated-data"}
105+
)
106+
113107

114108
@pytest.mark.asyncio
115109
async def test_get_item_by_id(mock_config, mock_cosmos_client):
@@ -124,72 +118,128 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
124118
context = CosmosBufferedChatCompletionContext(
125119
session_id="test_session", user_id="test_user"
126120
)
127-
try:
128-
await context.initialize()
129-
result = await context.get_item_by_id(
130-
"test-item", "test-partition", mock_model_class
131-
)
132-
assert result == "validated_item"
133-
mock_container.read_item.assert_called_once_with(
134-
item="test-item", partition_key="test-partition"
135-
)
136-
finally:
137-
await context.close()
121+
await context.initialize()
122+
result = await context.get_item_by_id(
123+
"test-item", "test-partition", mock_model_class
124+
)
125+
126+
assert result == "validated_item"
127+
mock_container.read_item.assert_called_once_with(
128+
item="test-item", partition_key="test-partition"
129+
)
130+
138131

139132
@pytest.mark.asyncio
140133
async def test_delete_item(mock_config, mock_cosmos_client):
141134
"""Test deleting an item from Cosmos DB."""
142135
_, mock_container = mock_cosmos_client
136+
143137
context = CosmosBufferedChatCompletionContext(
144138
session_id="test_session", user_id="test_user"
145139
)
146-
try:
147-
await context.initialize()
148-
await context.delete_item("test-item", "test-partition")
149-
mock_container.delete_item.assert_called_once_with(
150-
item="test-item", partition_key="test-partition"
151-
)
152-
finally:
153-
await context.close()
140+
await context.initialize()
141+
await context.delete_item("test-item", "test-partition")
142+
143+
mock_container.delete_item.assert_called_once_with(
144+
item="test-item", partition_key="test-partition"
145+
)
146+
154147

155148
@pytest.mark.asyncio
156-
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
157-
"""Test querying data with an invalid type."""
149+
async def test_add_plan(mock_config, mock_cosmos_client):
150+
"""Test adding a plan to Cosmos DB."""
151+
_, mock_container = mock_cosmos_client
152+
mock_plan = MagicMock()
153+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
154+
158155
context = CosmosBufferedChatCompletionContext(
159156
session_id="test_session", user_id="test_user"
160157
)
161-
try:
162-
result = await context.get_data_by_type("invalid_type")
163-
assert result == [] # Expect empty result for invalid type
164-
finally:
165-
await context.close()
158+
await context.initialize()
159+
await context.add_plan(mock_plan)
160+
161+
mock_container.create_item.assert_called_once_with(
162+
body={"id": "plan1", "data": "plan-data"}
163+
)
164+
166165

167166
@pytest.mark.asyncio
168-
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
169-
"""Test retrieving a plan with an invalid session ID."""
167+
async def test_update_plan(mock_config, mock_cosmos_client):
168+
"""Test updating a plan in Cosmos DB."""
170169
_, mock_container = mock_cosmos_client
171-
mock_container.query_items.return_value = async_iterable([]) # No results
170+
mock_plan = MagicMock()
171+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"}
172172

173173
context = CosmosBufferedChatCompletionContext(
174174
session_id="test_session", user_id="test_user"
175175
)
176-
try:
177-
await context.initialize()
178-
result = await context.get_plan_by_session("invalid_session")
179-
assert result is None
180-
finally:
181-
await context.close()
176+
await context.initialize()
177+
await context.update_plan(mock_plan)
178+
179+
mock_container.upsert_item.assert_called_once_with(
180+
body={"id": "plan1", "data": "updated-plan-data"}
181+
)
182+
182183

183184
@pytest.mark.asyncio
184-
async def test_close_without_initialization(mock_config):
185-
"""Test close method without prior initialization."""
185+
async def test_add_session(mock_config, mock_cosmos_client):
186+
"""Test adding a session to Cosmos DB."""
187+
_, mock_container = mock_cosmos_client
188+
mock_session = MagicMock()
189+
mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"}
190+
186191
context = CosmosBufferedChatCompletionContext(
187192
session_id="test_session", user_id="test_user"
188193
)
189-
try:
190-
await context.close()
191-
except Exception as e:
192-
pytest.fail(f"Unexpected exception during close: {e}")
194+
await context.initialize()
195+
await context.add_session(mock_session)
196+
197+
mock_container.create_item.assert_called_once_with(
198+
body={"id": "session1", "data": "session-data"}
199+
)
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_initialize_event(mock_config, mock_cosmos_client):
204+
"""Test the initialization event is set."""
205+
_, _ = mock_cosmos_client
206+
context = CosmosBufferedChatCompletionContext(
207+
session_id="test_session", user_id="test_user"
208+
)
209+
assert not context._initialized.is_set()
210+
await context.initialize()
211+
assert context._initialized.is_set()
212+
213+
214+
@pytest.mark.asyncio
215+
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
216+
"""Test querying data with an invalid type."""
217+
_, _ = mock_cosmos_client
218+
context = CosmosBufferedChatCompletionContext(
219+
session_id="test_session", user_id="test_user"
220+
)
221+
222+
result = await context.get_data_by_type("invalid_type")
223+
224+
assert result == [] # Expect empty result for invalid type
225+
226+
227+
@pytest.mark.asyncio
228+
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
229+
"""Test retrieving a plan with an invalid session ID."""
230+
_, mock_container = mock_cosmos_client
231+
mock_container.query_items.return_value = async_iterable(
232+
[]
233+
) # No results for invalid session
234+
235+
context = CosmosBufferedChatCompletionContext(
236+
session_id="test_session", user_id="test_user"
237+
)
238+
await context.initialize()
239+
result = await context.get_plan_by_session("invalid_session")
240+
241+
assert result is None
242+
193243

194244
@pytest.mark.asyncio
195245
async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
@@ -200,8 +250,17 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
200250
context = CosmosBufferedChatCompletionContext(
201251
session_id="test_session", user_id="test_user"
202252
)
203-
try:
204-
await context.initialize()
205-
await context.delete_item("test-item", "test-partition")
206-
finally:
207-
await context.close()
253+
await context.initialize()
254+
await context.delete_item(
255+
"test-item", "test-partition"
256+
) # Expect no exception to propagate
257+
258+
259+
@pytest.mark.asyncio
260+
async def test_close_without_initialization(mock_config, mock_cosmos_client):
261+
"""Test close method without prior initialization."""
262+
context = CosmosBufferedChatCompletionContext(
263+
session_id="test_session", user_id="test_user"
264+
)
265+
# Expect no exceptions when closing uninitialized context
266+
await context.close()

0 commit comments

Comments
 (0)