@@ -37,6 +37,14 @@ def mock_env_variables(monkeypatch):
3737 monkeypatch .setenv (key , value )
3838
3939
40+ @pytest .fixture
41+ def mock_azure_credentials ():
42+ """Mock Azure DefaultAzureCredential for all tests."""
43+ with patch ("azure.identity.aio.DefaultAzureCredential" ) as mock_cred :
44+ mock_cred .return_value .get_token = AsyncMock (return_value = {"token" : "mock-token" })
45+ yield
46+
47+
4048@pytest .fixture
4149def mock_cosmos_client ():
4250 """Fixture for mocking Cosmos DB client and container."""
@@ -71,13 +79,12 @@ async def test_initialize(mock_config, mock_cosmos_client):
7179
7280
7381@pytest .mark .asyncio
74- async def test_close_without_initialization ():
82+ async def test_close_without_initialization (mock_config ):
7583 """Test closing the context without prior initialization."""
7684 async with CosmosBufferedChatCompletionContext (
7785 session_id = "test_session" , user_id = "test_user"
78- ) as context :
79- # Ensure close is safe even if not explicitly initialized
80- pass
86+ ):
87+ pass # Expect no errors when exiting context
8188
8289
8390@pytest .mark .asyncio
@@ -95,3 +102,116 @@ async def test_add_item(mock_config, mock_cosmos_client):
95102 mock_container .create_item .assert_called_once_with (
96103 body = {"id" : "test-item" , "data" : "test-data" }
97104 )
105+
106+
107+ @pytest .mark .asyncio
108+ async def test_update_item (mock_config , mock_cosmos_client ):
109+ """Test updating an item in Cosmos DB."""
110+ _ , mock_container = mock_cosmos_client
111+ mock_item = MagicMock ()
112+ mock_item .model_dump .return_value = {"id" : "test-item" , "data" : "updated-data" }
113+
114+ async with CosmosBufferedChatCompletionContext (
115+ session_id = "test_session" , user_id = "test_user"
116+ ) as context :
117+ await context .initialize ()
118+ await context .update_item (mock_item )
119+ mock_container .upsert_item .assert_called_once_with (
120+ body = {"id" : "test-item" , "data" : "updated-data" }
121+ )
122+
123+
124+ @pytest .mark .asyncio
125+ async def test_get_item_by_id (mock_config , mock_cosmos_client ):
126+ """Test retrieving an item by ID from Cosmos DB."""
127+ _ , mock_container = mock_cosmos_client
128+ mock_item = {"id" : "test-item" , "data" : "retrieved-data" }
129+ mock_container .read_item .return_value = mock_item
130+
131+ mock_model_class = MagicMock ()
132+ mock_model_class .model_validate .return_value = "validated_item"
133+
134+ async with CosmosBufferedChatCompletionContext (
135+ session_id = "test_session" , user_id = "test_user"
136+ ) as context :
137+ await context .initialize ()
138+ result = await context .get_item_by_id (
139+ "test-item" , "test-partition" , mock_model_class
140+ )
141+
142+ assert result == "validated_item"
143+ mock_container .read_item .assert_called_once_with (
144+ item = "test-item" , partition_key = "test-partition"
145+ )
146+
147+
148+ @pytest .mark .asyncio
149+ async def test_delete_item (mock_config , mock_cosmos_client ):
150+ """Test deleting an item from Cosmos DB."""
151+ _ , mock_container = mock_cosmos_client
152+
153+ async with CosmosBufferedChatCompletionContext (
154+ session_id = "test_session" , user_id = "test_user"
155+ ) as context :
156+ await context .initialize ()
157+ await context .delete_item ("test-item" , "test-partition" )
158+
159+ mock_container .delete_item .assert_called_once_with (
160+ item = "test-item" , partition_key = "test-partition"
161+ )
162+
163+
164+ @pytest .mark .asyncio
165+ async def test_add_plan (mock_config , mock_cosmos_client ):
166+ """Test adding a plan to Cosmos DB."""
167+ _ , mock_container = mock_cosmos_client
168+ mock_plan = MagicMock ()
169+ mock_plan .model_dump .return_value = {"id" : "plan1" , "data" : "plan-data" }
170+
171+ async with CosmosBufferedChatCompletionContext (
172+ session_id = "test_session" , user_id = "test_user"
173+ ) as context :
174+ await context .initialize ()
175+ await context .add_plan (mock_plan )
176+
177+ mock_container .create_item .assert_called_once_with (
178+ body = {"id" : "plan1" , "data" : "plan-data" }
179+ )
180+
181+
182+ @pytest .mark .asyncio
183+ async def test_update_plan (mock_config , mock_cosmos_client ):
184+ """Test updating a plan in Cosmos DB."""
185+ _ , mock_container = mock_cosmos_client
186+ mock_plan = MagicMock ()
187+ mock_plan .model_dump .return_value = {
188+ "id" : "plan1" ,
189+ "data" : "updated-plan-data" ,
190+ }
191+
192+ async with CosmosBufferedChatCompletionContext (
193+ session_id = "test_session" , user_id = "test_user"
194+ ) as context :
195+ await context .initialize ()
196+ await context .update_plan (mock_plan )
197+
198+ mock_container .upsert_item .assert_called_once_with (
199+ body = {"id" : "plan1" , "data" : "updated-plan-data" }
200+ )
201+
202+
203+ @pytest .mark .asyncio
204+ async def test_get_plan_by_invalid_session (mock_config , mock_cosmos_client ):
205+ """Test retrieving a plan with an invalid session ID."""
206+ _ , mock_container = mock_cosmos_client
207+ mock_container .query_items .return_value = async_iterable (
208+ []
209+ ) # No results for invalid session
210+
211+ async with CosmosBufferedChatCompletionContext (
212+ session_id = "test_session" , user_id = "test_user"
213+ ) as context :
214+ await context .initialize ()
215+ result = await context .get_plan_by_session ("invalid_session" )
216+
217+ assert result is None
0 commit comments