1313os .environ ["AZURE_OPENAI_API_VERSION" ] = "2023-01-01"
1414os .environ ["AZURE_OPENAI_ENDPOINT" ] = "https://mock-openai-endpoint"
1515
16+
1617async def async_iterable (mock_items ):
1718 """Helper to create an async iterable."""
1819 for item in mock_items :
1920 yield item
2021
22+
2123@pytest .fixture (autouse = True )
2224def mock_env_variables (monkeypatch ):
2325 """Mock all required environment variables."""
@@ -33,12 +35,6 @@ def mock_env_variables(monkeypatch):
3335 for key , value in env_vars .items ():
3436 monkeypatch .setenv (key , value )
3537
36- @pytest .fixture (autouse = True )
37- def mock_azure_credentials ():
38- """Mock Azure DefaultAzureCredential for all tests."""
39- with patch ("azure.identity.aio.DefaultAzureCredential" ) as mock_cred :
40- mock_cred .return_value .get_token = AsyncMock (return_value = {"token" : "mock-token" })
41- yield
4238
4339@pytest .fixture
4440def mock_cosmos_client ():
@@ -48,6 +44,7 @@ def mock_cosmos_client():
4844 mock_client .create_container_if_not_exists .return_value = mock_container
4945 return mock_client , mock_container
5046
47+
5148@pytest .fixture
5249def mock_config (mock_cosmos_client ):
5350 """Fixture to patch Config with mock Cosmos DB client."""
@@ -57,21 +54,20 @@ def mock_config(mock_cosmos_client):
5754 ), patch ("src.backend.config.Config.COSMOSDB_CONTAINER" , "mock-container" ):
5855 yield
5956
57+
6058@pytest .mark .asyncio
6159async def test_initialize (mock_config , mock_cosmos_client ):
6260 """Test if the Cosmos DB container is initialized correctly."""
6361 mock_client , mock_container = mock_cosmos_client
6462 context = CosmosBufferedChatCompletionContext (
6563 session_id = "test_session" , user_id = "test_user"
6664 )
67- try :
68- await context .initialize ()
69- mock_client .create_container_if_not_exists .assert_called_once_with (
70- id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
71- )
72- assert context ._container == mock_container
73- finally :
74- await context .close ()
65+ await context .initialize ()
66+ mock_client .create_container_if_not_exists .assert_called_once_with (
67+ id = "mock-container" , partition_key = PartitionKey (path = "/session_id" )
68+ )
69+ assert context ._container == mock_container
70+
7571
7672@pytest .mark .asyncio
7773async def test_add_item (mock_config , mock_cosmos_client ):
@@ -83,14 +79,13 @@ async def test_add_item(mock_config, mock_cosmos_client):
8379 context = CosmosBufferedChatCompletionContext (
8480 session_id = "test_session" , user_id = "test_user"
8581 )
86- try :
87- await context .initialize ()
88- await context .add_item (mock_item )
89- mock_container .create_item .assert_called_once_with (
90- body = {"id" : "test-item" , "data" : "test-data" }
91- )
92- finally :
93- await context .close ()
82+ await context .initialize ()
83+ await context .add_item (mock_item )
84+
85+ mock_container .create_item .assert_called_once_with (
86+ body = {"id" : "test-item" , "data" : "test-data" }
87+ )
88+
9489
9590@pytest .mark .asyncio
9691async def test_update_item (mock_config , mock_cosmos_client ):
@@ -102,14 +97,13 @@ async def test_update_item(mock_config, mock_cosmos_client):
10297 context = CosmosBufferedChatCompletionContext (
10398 session_id = "test_session" , user_id = "test_user"
10499 )
105- try :
106- await context .initialize ()
107- await context .update_item (mock_item )
108- mock_container .upsert_item .assert_called_once_with (
109- body = {"id" : "test-item" , "data" : "updated-data" }
110- )
111- finally :
112- await context .close ()
100+ await context .initialize ()
101+ await context .update_item (mock_item )
102+
103+ mock_container .upsert_item .assert_called_once_with (
104+ body = {"id" : "test-item" , "data" : "updated-data" }
105+ )
106+
113107
114108@pytest .mark .asyncio
115109async def test_get_item_by_id (mock_config , mock_cosmos_client ):
@@ -124,72 +118,128 @@ async def test_get_item_by_id(mock_config, mock_cosmos_client):
124118 context = CosmosBufferedChatCompletionContext (
125119 session_id = "test_session" , user_id = "test_user"
126120 )
127- try :
128- await context .initialize ()
129- result = await context .get_item_by_id (
130- "test-item" , "test-partition" , mock_model_class
131- )
132- assert result == "validated_item"
133- mock_container .read_item .assert_called_once_with (
134- item = "test-item" , partition_key = "test-partition"
135- )
136- finally :
137- await context .close ()
121+ await context .initialize ()
122+ result = await context .get_item_by_id (
123+ "test-item" , "test-partition" , mock_model_class
124+ )
125+
126+ assert result == "validated_item"
127+ mock_container .read_item .assert_called_once_with (
128+ item = "test-item" , partition_key = "test-partition"
129+ )
130+
138131
139132@pytest .mark .asyncio
140133async def test_delete_item (mock_config , mock_cosmos_client ):
141134 """Test deleting an item from Cosmos DB."""
142135 _ , mock_container = mock_cosmos_client
136+
143137 context = CosmosBufferedChatCompletionContext (
144138 session_id = "test_session" , user_id = "test_user"
145139 )
146- try :
147- await context .initialize ()
148- await context .delete_item ("test-item" , "test-partition" )
149- mock_container .delete_item .assert_called_once_with (
150- item = "test-item" , partition_key = "test-partition"
151- )
152- finally :
153- await context .close ()
140+ await context .initialize ()
141+ await context .delete_item ("test-item" , "test-partition" )
142+
143+ mock_container .delete_item .assert_called_once_with (
144+ item = "test-item" , partition_key = "test-partition"
145+ )
146+
154147
155148@pytest .mark .asyncio
156- async def test_get_data_by_invalid_type (mock_config , mock_cosmos_client ):
157- """Test querying data with an invalid type."""
149+ async def test_add_plan (mock_config , mock_cosmos_client ):
150+ """Test adding a plan to Cosmos DB."""
151+ _ , mock_container = mock_cosmos_client
152+ mock_plan = MagicMock ()
153+ mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "plan-data" }
154+
158155 context = CosmosBufferedChatCompletionContext (
159156 session_id = "test_session" , user_id = "test_user"
160157 )
161- try :
162- result = await context .get_data_by_type ("invalid_type" )
163- assert result == [] # Expect empty result for invalid type
164- finally :
165- await context .close ()
158+ await context .initialize ()
159+ await context .add_plan (mock_plan )
160+
161+ mock_container .create_item .assert_called_once_with (
162+ body = {"id" : "plan1" , "data" : "plan-data" }
163+ )
164+
166165
167166@pytest .mark .asyncio
168- async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
169- """Test retrieving a plan with an invalid session ID ."""
167+ async def test_update_plan (mock_config , mock_cosmos_client ):
168+ """Test updating a plan in Cosmos DB ."""
170169 _ , mock_container = mock_cosmos_client
171- mock_container .query_items .return_value = async_iterable ([]) # No results
170+ mock_plan = MagicMock ()
171+ mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "updated-plan-data" }
172172
173173 context = CosmosBufferedChatCompletionContext (
174174 session_id = "test_session" , user_id = "test_user"
175175 )
176- try :
177- await context .initialize ()
178- result = await context .get_plan_by_session ("invalid_session" )
179- assert result is None
180- finally :
181- await context .close ()
176+ await context .initialize ()
177+ await context .update_plan (mock_plan )
178+
179+ mock_container .upsert_item .assert_called_once_with (
180+ body = {"id" : "plan1" , "data" : "updated-plan-data" }
181+ )
182+
182183
183184@pytest .mark .asyncio
184- async def test_close_without_initialization (mock_config ):
185- """Test close method without prior initialization."""
185+ async def test_add_session (mock_config , mock_cosmos_client ):
186+ """Test adding a session to Cosmos DB."""
187+ _ , mock_container = mock_cosmos_client
188+ mock_session = MagicMock ()
189+ mock_session .model_dump .return_value = {"id" : "session1" , "data" : "session-data" }
190+
186191 context = CosmosBufferedChatCompletionContext (
187192 session_id = "test_session" , user_id = "test_user"
188193 )
189- try :
190- await context .close ()
191- except Exception as e :
192- pytest .fail (f"Unexpected exception during close: { e } " )
194+ await context .initialize ()
195+ await context .add_session (mock_session )
196+
197+ mock_container .create_item .assert_called_once_with (
198+ body = {"id" : "session1" , "data" : "session-data" }
199+ )
200+
201+
202+ @pytest .mark .asyncio
203+ async def test_initialize_event (mock_config , mock_cosmos_client ):
204+ """Test the initialization event is set."""
205+ _ , _ = mock_cosmos_client
206+ context = CosmosBufferedChatCompletionContext (
207+ session_id = "test_session" , user_id = "test_user"
208+ )
209+ assert not context ._initialized .is_set ()
210+ await context .initialize ()
211+ assert context ._initialized .is_set ()
212+
213+
214+ @pytest .mark .asyncio
215+ async def test_get_data_by_invalid_type (mock_config , mock_cosmos_client ):
216+ """Test querying data with an invalid type."""
217+ _ , _ = mock_cosmos_client
218+ context = CosmosBufferedChatCompletionContext (
219+ session_id = "test_session" , user_id = "test_user"
220+ )
221+
222+ result = await context .get_data_by_type ("invalid_type" )
223+
224+ assert result == [] # Expect empty result for invalid type
225+
226+
227+ @pytest .mark .asyncio
228+ async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
229+ """Test retrieving a plan with an invalid session ID."""
230+ _ , mock_container = mock_cosmos_client
231+ mock_container .query_items .return_value = async_iterable (
232+ []
233+ ) # No results for invalid session
234+
235+ context = CosmosBufferedChatCompletionContext (
236+ session_id = "test_session" , user_id = "test_user"
237+ )
238+ await context .initialize ()
239+ result = await context .get_plan_by_session ("invalid_session" )
240+
241+ assert result is None
242+
193243
194244@pytest .mark .asyncio
195245async def test_delete_item_error_handling (mock_config , mock_cosmos_client ):
@@ -200,8 +250,17 @@ async def test_delete_item_error_handling(mock_config, mock_cosmos_client):
200250 context = CosmosBufferedChatCompletionContext (
201251 session_id = "test_session" , user_id = "test_user"
202252 )
203- try :
204- await context .initialize ()
205- await context .delete_item ("test-item" , "test-partition" )
206- finally :
207- await context .close ()
253+ await context .initialize ()
254+ await context .delete_item (
255+ "test-item" , "test-partition"
256+ ) # Expect no exception to propagate
257+
258+
259+ @pytest .mark .asyncio
260+ async def test_close_without_initialization (mock_config , mock_cosmos_client ):
261+ """Test close method without prior initialization."""
262+ context = CosmosBufferedChatCompletionContext (
263+ session_id = "test_session" , user_id = "test_user"
264+ )
265+ # Expect no exceptions when closing uninitialized context
266+ await context .close ()
0 commit comments