Skip to content

Commit b1182d3

Browse files
Testcases
1 parent 1e0531d commit b1182d3

File tree

1 file changed

+118
-68
lines changed

1 file changed

+118
-68
lines changed

src/backend/tests/context/test_cosmos_memory.py

Lines changed: 118 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from azure.cosmos.partition_key import PartitionKey
55
from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext
66

7-
# Mock environment variables
7+
# Set environment variables globally before importing modules
88
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
99
os.environ["COSMOSDB_KEY"] = "mock-key"
1010
os.environ["COSMOSDB_DATABASE"] = "mock-database"
@@ -62,14 +62,11 @@ async def test_initialize(mock_config, mock_cosmos_client):
6262
context = CosmosBufferedChatCompletionContext(
6363
session_id="test_session", user_id="test_user"
6464
)
65-
try:
66-
await context.initialize()
67-
mock_client.create_container_if_not_exists.assert_called_once_with(
68-
id="mock-container", partition_key=PartitionKey(path="/session_id")
69-
)
70-
assert context._container == mock_container
71-
finally:
72-
await context.close()
65+
await context.initialize()
66+
mock_client.create_container_if_not_exists.assert_called_once_with(
67+
id="mock-container", partition_key=PartitionKey(path="/session_id")
68+
)
69+
assert context._container == mock_container
7370

7471

7572
@pytest.mark.asyncio
@@ -82,14 +79,12 @@ async def test_add_item(mock_config, mock_cosmos_client):
8279
context = CosmosBufferedChatCompletionContext(
8380
session_id="test_session", user_id="test_user"
8481
)
85-
try:
86-
await context.initialize()
87-
await context.add_item(mock_item)
88-
mock_container.create_item.assert_called_once_with(
89-
body={"id": "test-item", "data": "test-data"}
90-
)
91-
finally:
92-
await context.close()
82+
await context.initialize()
83+
await context.add_item(mock_item)
84+
85+
mock_container.create_item.assert_called_once_with(
86+
body={"id": "test-item", "data": "test-data"}
87+
)
9388

9489

9590
@pytest.mark.asyncio
@@ -102,14 +97,12 @@ async def test_update_item(mock_config, mock_cosmos_client):
10297
context = CosmosBufferedChatCompletionContext(
10398
session_id="test_session", user_id="test_user"
10499
)
105-
try:
106-
await context.initialize()
107-
await context.update_item(mock_item)
108-
mock_container.upsert_item.assert_called_once_with(
109-
body={"id": "test-item", "data": "updated-data"}
110-
)
111-
finally:
112-
await context.close()
100+
await context.initialize()
101+
await context.update_item(mock_item)
102+
103+
mock_container.upsert_item.assert_called_once_with(
104+
body={"id": "test-item", "data": "updated-data"}
105+
)
113106

114107

115108
@pytest.mark.asyncio
@@ -125,17 +118,15 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
125118
context = CosmosBufferedChatCompletionContext(
126119
session_id="test_session", user_id="test_user"
127120
)
128-
try:
129-
await context.initialize()
130-
result = await context.get_item_by_id(
131-
"test-item", "test-partition", mock_model_class
132-
)
133-
assert result == "validated_item"
134-
mock_container.read_item.assert_called_once_with(
135-
item="test-item", partition_key="test-partition"
136-
)
137-
finally:
138-
await context.close()
121+
await context.initialize()
122+
result = await context.get_item_by_id(
123+
"test-item", "test-partition", mock_model_class
124+
)
125+
126+
assert result == "validated_item"
127+
mock_container.read_item.assert_called_once_with(
128+
item="test-item", partition_key="test-partition"
129+
)
139130

140131

141132
@pytest.mark.asyncio
@@ -146,58 +137,108 @@ async def test_delete_item(mock_config, mock_cosmos_client):
146137
context = CosmosBufferedChatCompletionContext(
147138
session_id="test_session", user_id="test_user"
148139
)
149-
try:
150-
await context.initialize()
151-
await context.delete_item("test-item", "test-partition")
152-
mock_container.delete_item.assert_called_once_with(
153-
item="test-item", partition_key="test-partition"
154-
)
155-
finally:
156-
await context.close()
140+
await context.initialize()
141+
await context.delete_item("test-item", "test-partition")
142+
143+
mock_container.delete_item.assert_called_once_with(
144+
item="test-item", partition_key="test-partition"
145+
)
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_add_plan(mock_config, mock_cosmos_client):
150+
"""Test adding a plan to Cosmos DB."""
151+
_, mock_container = mock_cosmos_client
152+
mock_plan = MagicMock()
153+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"}
154+
155+
context = CosmosBufferedChatCompletionContext(
156+
session_id="test_session", user_id="test_user"
157+
)
158+
await context.initialize()
159+
await context.add_plan(mock_plan)
160+
161+
mock_container.create_item.assert_called_once_with(
162+
body={"id": "plan1", "data": "plan-data"}
163+
)
164+
165+
166+
@pytest.mark.asyncio
167+
async def test_update_plan(mock_config, mock_cosmos_client):
168+
"""Test updating a plan in Cosmos DB."""
169+
_, mock_container = mock_cosmos_client
170+
mock_plan = MagicMock()
171+
mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"}
172+
173+
context = CosmosBufferedChatCompletionContext(
174+
session_id="test_session", user_id="test_user"
175+
)
176+
await context.initialize()
177+
await context.update_plan(mock_plan)
178+
179+
mock_container.upsert_item.assert_called_once_with(
180+
body={"id": "plan1", "data": "updated-plan-data"}
181+
)
157182

158183

159184
@pytest.mark.asyncio
160-
async def test_close_without_initialization():
161-
"""Test closing the context without prior initialization."""
185+
async def test_add_session(mock_config, mock_cosmos_client):
186+
"""Test adding a session to Cosmos DB."""
187+
_, mock_container = mock_cosmos_client
188+
mock_session = MagicMock()
189+
mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"}
190+
162191
context = CosmosBufferedChatCompletionContext(
163192
session_id="test_session", user_id="test_user"
164193
)
165-
try:
166-
await context.close() # Should handle gracefully even if not initialized
167-
except Exception as e:
168-
pytest.fail(f"Unexpected exception during close: {e}")
194+
await context.initialize()
195+
await context.add_session(mock_session)
196+
197+
mock_container.create_item.assert_called_once_with(
198+
body={"id": "session1", "data": "session-data"}
199+
)
169200

170201

171202
@pytest.mark.asyncio
172203
async def test_initialize_event(mock_config, mock_cosmos_client):
173-
"""Test if the initialization flag is correctly set."""
204+
"""Test the initialization event is set."""
174205
_, _ = mock_cosmos_client
175206
context = CosmosBufferedChatCompletionContext(
176207
session_id="test_session", user_id="test_user"
177208
)
178209
assert not context._initialized.is_set()
179-
try:
180-
await context.initialize()
181-
assert context._initialized.is_set()
182-
finally:
183-
await context.close()
210+
await context.initialize()
211+
assert context._initialized.is_set()
212+
213+
214+
@pytest.mark.asyncio
215+
async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client):
216+
"""Test querying data with an invalid type."""
217+
_, _ = mock_cosmos_client
218+
context = CosmosBufferedChatCompletionContext(
219+
session_id="test_session", user_id="test_user"
220+
)
221+
222+
result = await context.get_data_by_type("invalid_type")
223+
224+
assert result == [] # Expect empty result for invalid type
184225

185226

186227
@pytest.mark.asyncio
187228
async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client):
188229
"""Test retrieving a plan with an invalid session ID."""
189230
_, mock_container = mock_cosmos_client
190-
mock_container.query_items.return_value = async_iterable([]) # No results
231+
mock_container.query_items.return_value = async_iterable(
232+
[]
233+
) # No results for invalid session
191234

192235
context = CosmosBufferedChatCompletionContext(
193236
session_id="test_session", user_id="test_user"
194237
)
195-
try:
196-
await context.initialize()
197-
result = await context.get_plan_by_session("invalid_session")
198-
assert result is None
199-
finally:
200-
await context.close()
238+
await context.initialize()
239+
result = await context.get_plan_by_session("invalid_session")
240+
241+
assert result is None
201242

202243

203244
@pytest.mark.asyncio
@@ -209,8 +250,17 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
209250
context = CosmosBufferedChatCompletionContext(
210251
session_id="test_session", user_id="test_user"
211252
)
212-
try:
213-
await context.initialize()
214-
await context.delete_item("test-item", "test-partition")
215-
finally:
216-
await context.close()
253+
await context.initialize()
254+
await context.delete_item(
255+
"test-item", "test-partition"
256+
) # Expect no exception to propagate
257+
258+
259+
@pytest.mark.asyncio
260+
async def test_close_without_initialization(mock_config, mock_cosmos_client):
261+
"""Test close method without prior initialization."""
262+
context = CosmosBufferedChatCompletionContext(
263+
session_id="test_session", user_id="test_user"
264+
)
265+
# Expect no exceptions when closing uninitialized context
266+
await context.close()

0 commit comments

Comments
 (0)