44from azure .cosmos .partition_key import PartitionKey
55from src .backend .context .cosmos_memory import CosmosBufferedChatCompletionContext
66
7- # Set environment variables globally before importing modules
7+ # Mock environment variables
88os .environ ["COSMOSDB_ENDPOINT" ] = "https://mock-endpoint"
99os .environ ["COSMOSDB_KEY" ] = "mock-key"
1010os .environ ["COSMOSDB_DATABASE" ] = "mock-database"
@@ -62,11 +62,14 @@ async def test_initialize(mock_config, mock_cosmos_client):
6262 context = CosmosBufferedChatCompletionContext (
6363 session_id = "test_session" , user_id = "test_user"
6464 )
65- await context .initialize ()
66- mock_client .create_container_if_not_exists .assert_called_once_with (
67- id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
68- )
69- assert context ._container == mock_container
65+ try :
66+ await context .initialize ()
67+ mock_client .create_container_if_not_exists .assert_called_once_with (
68+ id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
69+ )
70+ assert context ._container == mock_container
71+ finally :
72+ await context .close ()
7073
7174
7275@pytest .mark .asyncio
@@ -79,12 +82,14 @@ async def test_add_item(mock_config, mock_cosmos_client):
7982 context = CosmosBufferedChatCompletionContext (
8083 session_id = "test_session" , user_id = "test_user"
8184 )
82- await context .initialize ()
83- await context .add_item (mock_item )
84-
85- mock_container .create_item .assert_called_once_with (
86- body = {"id" : "test-item" , "data" : "test-data" }
87- )
85+ try :
86+ await context .initialize ()
87+ await context .add_item (mock_item )
88+ mock_container .create_item .assert_called_once_with (
89+ body = {"id" : "test-item" , "data" : "test-data" }
90+ )
91+ finally :
92+ await context .close ()
8893
8994
9095@pytest .mark .asyncio
@@ -97,12 +102,14 @@ async def test_update_item(mock_config, mock_cosmos_client):
97102 context = CosmosBufferedChatCompletionContext (
98103 session_id = "test_session" , user_id = "test_user"
99104 )
100- await context .initialize ()
101- await context .update_item (mock_item )
102-
103- mock_container .upsert_item .assert_called_once_with (
104- body = {"id" : "test-item" , "data" : "updated-data" }
105- )
105+ try :
106+ await context .initialize ()
107+ await context .update_item (mock_item )
108+ mock_container .upsert_item .assert_called_once_with (
109+ body = {"id" : "test-item" , "data" : "updated-data" }
110+ )
111+ finally :
112+ await context .close ()
106113
107114
108115@pytest .mark .asyncio
@@ -118,15 +125,17 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
118125 context = CosmosBufferedChatCompletionContext (
119126 session_id = "test_session" , user_id = "test_user"
120127 )
121- await context .initialize ()
122- result = await context .get_item_by_id (
123- "test-item" , "test-partition" , mock_model_class
124- )
125-
126- assert result == "validated_item"
127- mock_container .read_item .assert_called_once_with (
128- item = "test-item" , partition_key = "test-partition"
129- )
128+ try :
129+ await context .initialize ()
130+ result = await context .get_item_by_id (
131+ "test-item" , "test-partition" , mock_model_class
132+ )
133+ assert result == "validated_item"
134+ mock_container .read_item .assert_called_once_with (
135+ item = "test-item" , partition_key = "test-partition"
136+ )
137+ finally :
138+ await context .close ()
130139
131140
132141@pytest .mark .asyncio
@@ -137,108 +146,58 @@ async def test_delete_item(mock_config, mock_cosmos_client):
137146 context = CosmosBufferedChatCompletionContext (
138147 session_id = "test_session" , user_id = "test_user"
139148 )
140- await context .initialize ()
141- await context .delete_item ("test-item" , "test-partition" )
142-
143- mock_container .delete_item .assert_called_once_with (
144- item = "test-item" , partition_key = "test-partition"
145- )
146-
147-
148- @pytest .mark .asyncio
149- async def test_add_plan (mock_config , mock_cosmos_client ):
150- """Test adding a plan to Cosmos DB."""
151- _ , mock_container = mock_cosmos_client
152- mock_plan = MagicMock ()
153- mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "plan-data" }
154-
155- context = CosmosBufferedChatCompletionContext (
156- session_id = "test_session" , user_id = "test_user"
157- )
158- await context .initialize ()
159- await context .add_plan (mock_plan )
160-
161- mock_container .create_item .assert_called_once_with (
162- body = {"id" : "plan1" , "data" : "plan-data" }
163- )
164-
165-
166- @pytest .mark .asyncio
167- async def test_update_plan (mock_config , mock_cosmos_client ):
168- """Test updating a plan in Cosmos DB."""
169- _ , mock_container = mock_cosmos_client
170- mock_plan = MagicMock ()
171- mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "updated-plan-data" }
172-
173- context = CosmosBufferedChatCompletionContext (
174- session_id = "test_session" , user_id = "test_user"
175- )
176- await context .initialize ()
177- await context .update_plan (mock_plan )
178-
179- mock_container .upsert_item .assert_called_once_with (
180- body = {"id" : "plan1" , "data" : "updated-plan-data" }
181- )
149+ try :
150+ await context .initialize ()
151+ await context .delete_item ("test-item" , "test-partition" )
152+ mock_container .delete_item .assert_called_once_with (
153+ item = "test-item" , partition_key = "test-partition"
154+ )
155+ finally :
156+ await context .close ()
182157
183158
184159@pytest .mark .asyncio
185- async def test_add_session (mock_config , mock_cosmos_client ):
186- """Test adding a session to Cosmos DB."""
187- _ , mock_container = mock_cosmos_client
188- mock_session = MagicMock ()
189- mock_session .model_dump .return_value = {"id" : "session1" , "data" : "session-data" }
190-
160+ async def test_close_without_initialization ():
161+ """Test closing the context without prior initialization."""
191162 context = CosmosBufferedChatCompletionContext (
192163 session_id = "test_session" , user_id = "test_user"
193164 )
194- await context .initialize ()
195- await context .add_session (mock_session )
196-
197- mock_container .create_item .assert_called_once_with (
198- body = {"id" : "session1" , "data" : "session-data" }
199- )
165+ try :
166+ await context .close () # Should handle gracefully even if not initialized
167+ except Exception as e :
168+ pytest .fail (f"Unexpected exception during close: { e } " )
200169
201170
202171@pytest .mark .asyncio
203172async def test_initialize_event (mock_config , mock_cosmos_client ):
204- """Test the initialization event is set."""
173+ """Test if the initialization flag is correctly set."""
205174 _ , _ = mock_cosmos_client
206175 context = CosmosBufferedChatCompletionContext (
207176 session_id = "test_session" , user_id = "test_user"
208177 )
209178 assert not context ._initialized .is_set ()
210- await context .initialize ()
211- assert context ._initialized .is_set ()
212-
213-
214- @pytest .mark .asyncio
215- async def test_get_data_by_invalid_type (mock_config , mock_cosmos_client ):
216- """Test querying data with an invalid type."""
217- _ , _ = mock_cosmos_client
218- context = CosmosBufferedChatCompletionContext (
219- session_id = "test_session" , user_id = "test_user"
220- )
221-
222- result = await context .get_data_by_type ("invalid_type" )
223-
224- assert result == [] # Expect empty result for invalid type
179+ try :
180+ await context .initialize ()
181+ assert context ._initialized .is_set ()
182+ finally :
183+ await context .close ()
225184
226185
227186@pytest .mark .asyncio
228187async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
229188 """Test retrieving a plan with an invalid session ID."""
230189 _ , mock_container = mock_cosmos_client
231- mock_container .query_items .return_value = async_iterable (
232- []
233- ) # No results for invalid session
190+ mock_container .query_items .return_value = async_iterable ([]) # No results
234191
235192 context = CosmosBufferedChatCompletionContext (
236193 session_id = "test_session" , user_id = "test_user"
237194 )
238- await context .initialize ()
239- result = await context .get_plan_by_session ("invalid_session" )
240-
241- assert result is None
195+ try :
196+ await context .initialize ()
197+ result = await context .get_plan_by_session ("invalid_session" )
198+ assert result is None
199+ finally :
200+ await context .close ()
242201
243202
244203@pytest .mark .asyncio
@@ -250,17 +209,8 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
250209 context = CosmosBufferedChatCompletionContext (
251210 session_id = "test_session" , user_id = "test_user"
252211 )
253- await context .initialize ()
254- await context .delete_item (
255- "test-item" , "test-partition"
256- ) # Expect no exception to propagate
257-
258-
259- @pytest .mark .asyncio
260- async def test_close_without_initialization (mock_config , mock_cosmos_client ):
261- """Test close method without prior initialization."""
262- context = CosmosBufferedChatCompletionContext (
263- session_id = "test_session" , user_id = "test_user"
264- )
265- # Expect no exceptions when closing uninitialized context
266- await context .close ()
212+ try :
213+ await context .initialize ()
214+ await context .delete_item ("test-item" , "test-partition" )
215+ finally :
216+ await context .close ()
0 commit comments