1313os .environ ["AZURE_OPENAI_API_VERSION" ] = "2023-01-01"
1414os .environ ["AZURE_OPENAI_ENDPOINT" ] = "https://mock-openai-endpoint"
1515
16-
1716async def async_iterable (mock_items ):
1817 """Helper to create an async iterable."""
1918 for item in mock_items :
2019 yield item
2120
22-
2321@pytest .fixture (autouse = True )
2422def mock_env_variables (monkeypatch ):
2523 """Mock all required environment variables."""
@@ -35,6 +33,12 @@ def mock_env_variables(monkeypatch):
3533 for key , value in env_vars .items ():
3634 monkeypatch .setenv (key , value )
3735
36+ @pytest .fixture (autouse = True )
37+ def mock_azure_credentials ():
38+ """Mock Azure DefaultAzureCredential for all tests."""
39+ with patch ("azure.identity.aio.DefaultAzureCredential" ) as mock_cred :
40+ mock_cred .return_value .get_token = AsyncMock (return_value = {"token" : "mock-token" })
41+ yield
3842
3943@pytest .fixture
4044def mock_cosmos_client ():
@@ -44,7 +48,6 @@ def mock_cosmos_client():
4448 mock_client .create_container_if_not_exists .return_value = mock_container
4549 return mock_client , mock_container
4650
47-
4851@pytest .fixture
4952def mock_config (mock_cosmos_client ):
5053 """Fixture to patch Config with mock Cosmos DB client."""
@@ -54,20 +57,21 @@ def mock_config(mock_cosmos_client):
5457 ), patch ("src.backend.config.Config.COSMOSDB_CONTAINER" , "mock-container" ):
5558 yield
5659
57-
5860@pytest .mark .asyncio
5961async def test_initialize (mock_config , mock_cosmos_client ):
6062 """Test if the Cosmos DB container is initialized correctly."""
6163 mock_client , mock_container = mock_cosmos_client
6264 context = CosmosBufferedChatCompletionContext (
6365 session_id = "test_session" , user_id = "test_user"
6466 )
65- await context .initialize ()
66- mock_client .create_container_if_not_exists .assert_called_once_with (
67- id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
68- )
69- assert context ._container == mock_container
70-
67+ try :
68+ await context .initialize ()
69+ mock_client .create_container_if_not_exists .assert_called_once_with (
70+ id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
71+ )
72+ assert context ._container == mock_container
73+ finally :
74+ await context .close ()
7175
7276@pytest .mark .asyncio
7377async def test_add_item (mock_config , mock_cosmos_client ):
@@ -79,13 +83,14 @@ async def test_add_item(mock_config, mock_cosmos_client):
7983 context = CosmosBufferedChatCompletionContext (
8084 session_id = "test_session" , user_id = "test_user"
8185 )
82- await context .initialize ()
83- await context .add_item (mock_item )
84-
85- mock_container .create_item .assert_called_once_with (
86- body = {"id" : "test-item" , "data" : "test-data" }
87- )
88-
86+ try :
87+ await context .initialize ()
88+ await context .add_item (mock_item )
89+ mock_container .create_item .assert_called_once_with (
90+ body = {"id" : "test-item" , "data" : "test-data" }
91+ )
92+ finally :
93+ await context .close ()
8994
9095@pytest .mark .asyncio
9196async def test_update_item (mock_config , mock_cosmos_client ):
@@ -97,13 +102,14 @@ async def test_update_item(mock_config, mock_cosmos_client):
97102 context = CosmosBufferedChatCompletionContext (
98103 session_id = "test_session" , user_id = "test_user"
99104 )
100- await context .initialize ()
101- await context .update_item (mock_item )
102-
103- mock_container .upsert_item .assert_called_once_with (
104- body = {"id" : "test-item" , "data" : "updated-data" }
105- )
106-
105+ try :
106+ await context .initialize ()
107+ await context .update_item (mock_item )
108+ mock_container .upsert_item .assert_called_once_with (
109+ body = {"id" : "test-item" , "data" : "updated-data" }
110+ )
111+ finally :
112+ await context .close ()
107113
108114@pytest .mark .asyncio
109115async def test_get_item_by_id (mock_config , mock_cosmos_client ):
@@ -118,128 +124,72 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
118124 context = CosmosBufferedChatCompletionContext (
119125 session_id = "test_session" , user_id = "test_user"
120126 )
121- await context .initialize ()
122- result = await context .get_item_by_id (
123- "test-item" , "test-partition" , mock_model_class
124- )
125-
126- assert result == "validated_item"
127- mock_container .read_item .assert_called_once_with (
128- item = "test-item" , partition_key = "test-partition"
129- )
130-
127+ try :
128+ await context .initialize ()
129+ result = await context .get_item_by_id (
130+ "test-item" , "test-partition" , mock_model_class
131+ )
132+ assert result == "validated_item"
133+ mock_container .read_item .assert_called_once_with (
134+ item = "test-item" , partition_key = "test-partition"
135+ )
136+ finally :
137+ await context .close ()
131138
132139@pytest .mark .asyncio
133140async def test_delete_item (mock_config , mock_cosmos_client ):
134141 """Test deleting an item from Cosmos DB."""
135142 _ , mock_container = mock_cosmos_client
136-
137- context = CosmosBufferedChatCompletionContext (
138- session_id = "test_session" , user_id = "test_user"
139- )
140- await context .initialize ()
141- await context .delete_item ("test-item" , "test-partition" )
142-
143- mock_container .delete_item .assert_called_once_with (
144- item = "test-item" , partition_key = "test-partition"
145- )
146-
147-
148- @pytest .mark .asyncio
149- async def test_add_plan (mock_config , mock_cosmos_client ):
150- """Test adding a plan to Cosmos DB."""
151- _ , mock_container = mock_cosmos_client
152- mock_plan = MagicMock ()
153- mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "plan-data" }
154-
155- context = CosmosBufferedChatCompletionContext (
156- session_id = "test_session" , user_id = "test_user"
157- )
158- await context .initialize ()
159- await context .add_plan (mock_plan )
160-
161- mock_container .create_item .assert_called_once_with (
162- body = {"id" : "plan1" , "data" : "plan-data" }
163- )
164-
165-
166- @pytest .mark .asyncio
167- async def test_update_plan (mock_config , mock_cosmos_client ):
168- """Test updating a plan in Cosmos DB."""
169- _ , mock_container = mock_cosmos_client
170- mock_plan = MagicMock ()
171- mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "updated-plan-data" }
172-
173143 context = CosmosBufferedChatCompletionContext (
174144 session_id = "test_session" , user_id = "test_user"
175145 )
176- await context .initialize ()
177- await context .update_plan (mock_plan )
178-
179- mock_container .upsert_item .assert_called_once_with (
180- body = {"id" : "plan1" , "data" : "updated-plan-data" }
181- )
182-
183-
184- @pytest .mark .asyncio
185- async def test_add_session (mock_config , mock_cosmos_client ):
186- """Test adding a session to Cosmos DB."""
187- _ , mock_container = mock_cosmos_client
188- mock_session = MagicMock ()
189- mock_session .model_dump .return_value = {"id" : "session1" , "data" : "session-data" }
190-
191- context = CosmosBufferedChatCompletionContext (
192- session_id = "test_session" , user_id = "test_user"
193- )
194- await context .initialize ()
195- await context .add_session (mock_session )
196-
197- mock_container .create_item .assert_called_once_with (
198- body = {"id" : "session1" , "data" : "session-data" }
199- )
200-
201-
202- @pytest .mark .asyncio
203- async def test_initialize_event (mock_config , mock_cosmos_client ):
204- """Test the initialization event is set."""
205- _ , _ = mock_cosmos_client
206- context = CosmosBufferedChatCompletionContext (
207- session_id = "test_session" , user_id = "test_user"
208- )
209- assert not context ._initialized .is_set ()
210- await context .initialize ()
211- assert context ._initialized .is_set ()
212-
146+ try :
147+ await context .initialize ()
148+ await context .delete_item ("test-item" , "test-partition" )
149+ mock_container .delete_item .assert_called_once_with (
150+ item = "test-item" , partition_key = "test-partition"
151+ )
152+ finally :
153+ await context .close ()
213154
214155@pytest .mark .asyncio
215156async def test_get_data_by_invalid_type (mock_config , mock_cosmos_client ):
216157 """Test querying data with an invalid type."""
217- _ , _ = mock_cosmos_client
218158 context = CosmosBufferedChatCompletionContext (
219159 session_id = "test_session" , user_id = "test_user"
220160 )
221-
222- result = await context .get_data_by_type ("invalid_type" )
223-
224- assert result == [] # Expect empty result for invalid type
225-
161+ try :
162+ result = await context .get_data_by_type ("invalid_type" )
163+ assert result == [] # Expect empty result for invalid type
164+ finally :
165+ await context . close ()
226166
227167@pytest .mark .asyncio
228168async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
229169 """Test retrieving a plan with an invalid session ID."""
230170 _ , mock_container = mock_cosmos_client
231- mock_container .query_items .return_value = async_iterable (
232- []
233- ) # No results for invalid session
171+ mock_container .query_items .return_value = async_iterable ([]) # No results
234172
235173 context = CosmosBufferedChatCompletionContext (
236174 session_id = "test_session" , user_id = "test_user"
237175 )
238- await context .initialize ()
239- result = await context .get_plan_by_session ("invalid_session" )
240-
241- assert result is None
176+ try :
177+ await context .initialize ()
178+ result = await context .get_plan_by_session ("invalid_session" )
179+ assert result is None
180+ finally :
181+ await context .close ()
242182
183+ @pytest .mark .asyncio
184+ async def test_close_without_initialization (mock_config ):
185+ """Test close method without prior initialization."""
186+ context = CosmosBufferedChatCompletionContext (
187+ session_id = "test_session" , user_id = "test_user"
188+ )
189+ try :
190+ await context .close ()
191+ except Exception as e :
192+ pytest .fail (f"Unexpected exception during close: { e } " )
243193
244194@pytest .mark .asyncio
245195async def test_delete_item_error_handling (mock_config , mock_cosmos_client ):
@@ -250,17 +200,8 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
250200 context = CosmosBufferedChatCompletionContext (
251201 session_id = "test_session" , user_id = "test_user"
252202 )
253- await context .initialize ()
254- await context .delete_item (
255- "test-item" , "test-partition"
256- ) # Expect no exception to propagate
257-
258-
259- @pytest .mark .asyncio
260- async def test_close_without_initialization (mock_config , mock_cosmos_client ):
261- """Test close method without prior initialization."""
262- context = CosmosBufferedChatCompletionContext (
263- session_id = "test_session" , user_id = "test_user"
264- )
265- # Expect no exceptions when closing uninitialized context
266- await context .close ()
203+ try :
204+ await context .initialize ()
205+ await context .delete_item ("test-item" , "test-partition" )
206+ finally :
207+ await context .close ()
0 commit comments