44from azure .cosmos .partition_key import PartitionKey
55from src .backend .context .cosmos_memory import CosmosBufferedChatCompletionContext
66
7- # Mock environment variables
7+ # Set environment variables globally before importing modules
88os .environ ["COSMOSDB_ENDPOINT" ] = "https://mock-endpoint"
99os .environ ["COSMOSDB_KEY" ] = "mock-key"
1010os .environ ["COSMOSDB_DATABASE" ] = "mock-database"
@@ -62,14 +62,11 @@ async def test_initialize(mock_config, mock_cosmos_client):
6262 context = CosmosBufferedChatCompletionContext (
6363 session_id = "test_session" , user_id = "test_user"
6464 )
65- try :
66- await context .initialize ()
67- mock_client .create_container_if_not_exists .assert_called_once_with (
68- id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
69- )
70- assert context ._container == mock_container
71- finally :
72- await context .close ()
65+ await context .initialize ()
66+ mock_client .create_container_if_not_exists .assert_called_once_with (
67+ id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
68+ )
69+ assert context ._container == mock_container
7370
7471
7572@pytest .mark .asyncio
@@ -82,14 +79,12 @@ async def test_add_item(mock_config, mock_cosmos_client):
8279 context = CosmosBufferedChatCompletionContext (
8380 session_id = "test_session" , user_id = "test_user"
8481 )
85- try :
86- await context .initialize ()
87- await context .add_item (mock_item )
88- mock_container .create_item .assert_called_once_with (
89- body = {"id" : "test-item" , "data" : "test-data" }
90- )
91- finally :
92- await context .close ()
82+ await context .initialize ()
83+ await context .add_item (mock_item )
84+
85+ mock_container .create_item .assert_called_once_with (
86+ body = {"id" : "test-item" , "data" : "test-data" }
87+ )
9388
9489
9590@pytest .mark .asyncio
@@ -102,14 +97,12 @@ async def test_update_item(mock_config, mock_cosmos_client):
10297 context = CosmosBufferedChatCompletionContext (
10398 session_id = "test_session" , user_id = "test_user"
10499 )
105- try :
106- await context .initialize ()
107- await context .update_item (mock_item )
108- mock_container .upsert_item .assert_called_once_with (
109- body = {"id" : "test-item" , "data" : "updated-data" }
110- )
111- finally :
112- await context .close ()
100+ await context .initialize ()
101+ await context .update_item (mock_item )
102+
103+ mock_container .upsert_item .assert_called_once_with (
104+ body = {"id" : "test-item" , "data" : "updated-data" }
105+ )
113106
114107
115108@pytest .mark .asyncio
@@ -125,17 +118,15 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
125118 context = CosmosBufferedChatCompletionContext (
126119 session_id = "test_session" , user_id = "test_user"
127120 )
128- try :
129- await context .initialize ()
130- result = await context .get_item_by_id (
131- "test-item" , "test-partition" , mock_model_class
132- )
133- assert result == "validated_item"
134- mock_container .read_item .assert_called_once_with (
135- item = "test-item" , partition_key = "test-partition"
136- )
137- finally :
138- await context .close ()
121+ await context .initialize ()
122+ result = await context .get_item_by_id (
123+ "test-item" , "test-partition" , mock_model_class
124+ )
125+
126+ assert result == "validated_item"
127+ mock_container .read_item .assert_called_once_with (
128+ item = "test-item" , partition_key = "test-partition"
129+ )
139130
140131
141132@pytest .mark .asyncio
@@ -146,58 +137,108 @@ async def test_delete_item(mock_config, mock_cosmos_client):
146137 context = CosmosBufferedChatCompletionContext (
147138 session_id = "test_session" , user_id = "test_user"
148139 )
149- try :
150- await context .initialize ()
151- await context .delete_item ("test-item" , "test-partition" )
152- mock_container .delete_item .assert_called_once_with (
153- item = "test-item" , partition_key = "test-partition"
154- )
155- finally :
156- await context .close ()
140+ await context .initialize ()
141+ await context .delete_item ("test-item" , "test-partition" )
142+
143+ mock_container .delete_item .assert_called_once_with (
144+ item = "test-item" , partition_key = "test-partition"
145+ )
146+
147+
148+ @pytest .mark .asyncio
149+ async def test_add_plan (mock_config , mock_cosmos_client ):
150+ """Test adding a plan to Cosmos DB."""
151+ _ , mock_container = mock_cosmos_client
152+ mock_plan = MagicMock ()
153+ mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "plan-data" }
154+
155+ context = CosmosBufferedChatCompletionContext (
156+ session_id = "test_session" , user_id = "test_user"
157+ )
158+ await context .initialize ()
159+ await context .add_plan (mock_plan )
160+
161+ mock_container .create_item .assert_called_once_with (
162+ body = {"id" : "plan1" , "data" : "plan-data" }
163+ )
164+
165+
166+ @pytest .mark .asyncio
167+ async def test_update_plan (mock_config , mock_cosmos_client ):
168+ """Test updating a plan in Cosmos DB."""
169+ _ , mock_container = mock_cosmos_client
170+ mock_plan = MagicMock ()
171+ mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "updated-plan-data" }
172+
173+ context = CosmosBufferedChatCompletionContext (
174+ session_id = "test_session" , user_id = "test_user"
175+ )
176+ await context .initialize ()
177+ await context .update_plan (mock_plan )
178+
179+ mock_container .upsert_item .assert_called_once_with (
180+ body = {"id" : "plan1" , "data" : "updated-plan-data" }
181+ )
157182
158183
159184@pytest .mark .asyncio
160- async def test_close_without_initialization ():
161- """Test closing the context without prior initialization."""
185+ async def test_add_session (mock_config , mock_cosmos_client ):
186+ """Test adding a session to Cosmos DB."""
187+ _ , mock_container = mock_cosmos_client
188+ mock_session = MagicMock ()
189+ mock_session .model_dump .return_value = {"id" : "session1" , "data" : "session-data" }
190+
162191 context = CosmosBufferedChatCompletionContext (
163192 session_id = "test_session" , user_id = "test_user"
164193 )
165- try :
166- await context .close () # Should handle gracefully even if not initialized
167- except Exception as e :
168- pytest .fail (f"Unexpected exception during close: { e } " )
194+ await context .initialize ()
195+ await context .add_session (mock_session )
196+
197+ mock_container .create_item .assert_called_once_with (
198+ body = {"id" : "session1" , "data" : "session-data" }
199+ )
169200
170201
171202@pytest .mark .asyncio
172203async def test_initialize_event (mock_config , mock_cosmos_client ):
173- """Test if the initialization flag is correctly set."""
204+ """Test the initialization event is set."""
174205 _ , _ = mock_cosmos_client
175206 context = CosmosBufferedChatCompletionContext (
176207 session_id = "test_session" , user_id = "test_user"
177208 )
178209 assert not context ._initialized .is_set ()
179- try :
180- await context .initialize ()
181- assert context ._initialized .is_set ()
182- finally :
183- await context .close ()
210+ await context .initialize ()
211+ assert context ._initialized .is_set ()
212+
213+
214+ @pytest .mark .asyncio
215+ async def test_get_data_by_invalid_type (mock_config , mock_cosmos_client ):
216+ """Test querying data with an invalid type."""
217+ _ , _ = mock_cosmos_client
218+ context = CosmosBufferedChatCompletionContext (
219+ session_id = "test_session" , user_id = "test_user"
220+ )
221+
222+ result = await context .get_data_by_type ("invalid_type" )
223+
224+ assert result == [] # Expect empty result for invalid type
184225
185226
186227@pytest .mark .asyncio
187228async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
188229 """Test retrieving a plan with an invalid session ID."""
189230 _ , mock_container = mock_cosmos_client
190- mock_container .query_items .return_value = async_iterable ([]) # No results
231+ mock_container .query_items .return_value = async_iterable (
232+ []
233+ ) # No results for invalid session
191234
192235 context = CosmosBufferedChatCompletionContext (
193236 session_id = "test_session" , user_id = "test_user"
194237 )
195- try :
196- await context .initialize ()
197- result = await context .get_plan_by_session ("invalid_session" )
198- assert result is None
199- finally :
200- await context .close ()
238+ await context .initialize ()
239+ result = await context .get_plan_by_session ("invalid_session" )
240+
241+ assert result is None
201242
202243
203244@pytest .mark .asyncio
@@ -209,8 +250,17 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
209250 context = CosmosBufferedChatCompletionContext (
210251 session_id = "test_session" , user_id = "test_user"
211252 )
212- try :
213- await context .initialize ()
214- await context .delete_item ("test-item" , "test-partition" )
215- finally :
216- await context .close ()
253+ await context .initialize ()
254+ await context .delete_item (
255+ "test-item" , "test-partition"
256+ ) # Expect no exception to propagate
257+
258+
259+ @pytest .mark .asyncio
260+ async def test_close_without_initialization (mock_config , mock_cosmos_client ):
261+ """Test close method without prior initialization."""
262+ context = CosmosBufferedChatCompletionContext (
263+ session_id = "test_session" , user_id = "test_user"
264+ )
265+ # Expect no exceptions when closing uninitialized context
266+ await context .close ()
0 commit comments