|
4 | 4 | from azure.cosmos.partition_key import PartitionKey |
5 | 5 | from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext |
6 | 6 |
|
7 | | -# Set environment variables globally before importing modules |
| 7 | + |
| 8 | +# Mock environment variables |
8 | 9 | os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" |
9 | 10 | os.environ["COSMOSDB_KEY"] = "mock-key" |
10 | 11 | os.environ["COSMOSDB_DATABASE"] = "mock-database" |
@@ -59,208 +60,38 @@ def mock_config(mock_cosmos_client): |
59 | 60 | async def test_initialize(mock_config, mock_cosmos_client): |
60 | 61 | """Test if the Cosmos DB container is initialized correctly.""" |
61 | 62 | mock_client, mock_container = mock_cosmos_client |
62 | | - context = CosmosBufferedChatCompletionContext( |
| 63 | + async with CosmosBufferedChatCompletionContext( |
63 | 64 | session_id="test_session", user_id="test_user" |
64 | | - ) |
65 | | - await context.initialize() |
66 | | - mock_client.create_container_if_not_exists.assert_called_once_with( |
67 | | - id="mock-container", partition_key=PartitionKey(path="/session_id") |
68 | | - ) |
69 | | - assert context._container == mock_container |
| 65 | + ) as context: |
| 66 | + await context.initialize() |
| 67 | + mock_client.create_container_if_not_exists.assert_called_once_with( |
| 68 | + id="mock-container", partition_key=PartitionKey(path="/session_id") |
| 69 | + ) |
| 70 | + assert context._container == mock_container |
70 | 71 |
|
71 | 72 |
|
72 | 73 | @pytest.mark.asyncio |
73 | | -async def test_add_item(mock_config, mock_cosmos_client): |
74 | | - """Test adding an item to Cosmos DB.""" |
75 | | - _, mock_container = mock_cosmos_client |
76 | | - mock_item = MagicMock() |
77 | | - mock_item.model_dump.return_value = {"id": "test-item", "data": "test-data"} |
78 | | - |
79 | | - context = CosmosBufferedChatCompletionContext( |
| 74 | +async def test_close_without_initialization(): |
| 75 | + """Test closing the context without prior initialization.""" |
| 76 | + async with CosmosBufferedChatCompletionContext( |
80 | 77 | session_id="test_session", user_id="test_user" |
81 | | - ) |
82 | | - await context.initialize() |
83 | | - await context.add_item(mock_item) |
84 | | - |
85 | | - mock_container.create_item.assert_called_once_with( |
86 | | - body={"id": "test-item", "data": "test-data"} |
87 | | - ) |
| 78 | + ) as context: |
| 79 | + # Ensure close is safe even if not explicitly initialized |
| 80 | + pass |
88 | 81 |
|
89 | 82 |
|
90 | 83 | @pytest.mark.asyncio |
91 | | -async def test_update_item(mock_config, mock_cosmos_client): |
92 | | - """Test updating an item in Cosmos DB.""" |
| 84 | +async def test_add_item(mock_config, mock_cosmos_client): |
| 85 | + """Test adding an item to Cosmos DB.""" |
93 | 86 | _, mock_container = mock_cosmos_client |
94 | 87 | mock_item = MagicMock() |
95 | | - mock_item.model_dump.return_value = {"id": "test-item", "data": "updated-data"} |
96 | | - |
97 | | - context = CosmosBufferedChatCompletionContext( |
98 | | - session_id="test_session", user_id="test_user" |
99 | | - ) |
100 | | - await context.initialize() |
101 | | - await context.update_item(mock_item) |
102 | | - |
103 | | - mock_container.upsert_item.assert_called_once_with( |
104 | | - body={"id": "test-item", "data": "updated-data"} |
105 | | - ) |
106 | | - |
107 | | - |
108 | | -@pytest.mark.asyncio |
109 | | -async def test_get_item_by_id(mock_config, mock_cosmos_client): |
110 | | - """Test retrieving an item by ID from Cosmos DB.""" |
111 | | - _, mock_container = mock_cosmos_client |
112 | | - mock_item = {"id": "test-item", "data": "retrieved-data"} |
113 | | - mock_container.read_item.return_value = mock_item |
114 | | - |
115 | | - mock_model_class = MagicMock() |
116 | | - mock_model_class.model_validate.return_value = "validated_item" |
117 | | - |
118 | | - context = CosmosBufferedChatCompletionContext( |
119 | | - session_id="test_session", user_id="test_user" |
120 | | - ) |
121 | | - await context.initialize() |
122 | | - result = await context.get_item_by_id( |
123 | | - "test-item", "test-partition", mock_model_class |
124 | | - ) |
125 | | - |
126 | | - assert result == "validated_item" |
127 | | - mock_container.read_item.assert_called_once_with( |
128 | | - item="test-item", partition_key="test-partition" |
129 | | - ) |
130 | | - |
131 | | - |
132 | | -@pytest.mark.asyncio |
133 | | -async def test_delete_item(mock_config, mock_cosmos_client): |
134 | | - """Test deleting an item from Cosmos DB.""" |
135 | | - _, mock_container = mock_cosmos_client |
136 | | - |
137 | | - context = CosmosBufferedChatCompletionContext( |
138 | | - session_id="test_session", user_id="test_user" |
139 | | - ) |
140 | | - await context.initialize() |
141 | | - await context.delete_item("test-item", "test-partition") |
142 | | - |
143 | | - mock_container.delete_item.assert_called_once_with( |
144 | | - item="test-item", partition_key="test-partition" |
145 | | - ) |
146 | | - |
147 | | - |
148 | | -@pytest.mark.asyncio |
149 | | -async def test_add_plan(mock_config, mock_cosmos_client): |
150 | | - """Test adding a plan to Cosmos DB.""" |
151 | | - _, mock_container = mock_cosmos_client |
152 | | - mock_plan = MagicMock() |
153 | | - mock_plan.model_dump.return_value = {"id": "plan1", "data": "plan-data"} |
154 | | - |
155 | | - context = CosmosBufferedChatCompletionContext( |
156 | | - session_id="test_session", user_id="test_user" |
157 | | - ) |
158 | | - await context.initialize() |
159 | | - await context.add_plan(mock_plan) |
160 | | - |
161 | | - mock_container.create_item.assert_called_once_with( |
162 | | - body={"id": "plan1", "data": "plan-data"} |
163 | | - ) |
164 | | - |
165 | | - |
166 | | -@pytest.mark.asyncio |
167 | | -async def test_update_plan(mock_config, mock_cosmos_client): |
168 | | - """Test updating a plan in Cosmos DB.""" |
169 | | - _, mock_container = mock_cosmos_client |
170 | | - mock_plan = MagicMock() |
171 | | - mock_plan.model_dump.return_value = {"id": "plan1", "data": "updated-plan-data"} |
172 | | - |
173 | | - context = CosmosBufferedChatCompletionContext( |
174 | | - session_id="test_session", user_id="test_user" |
175 | | - ) |
176 | | - await context.initialize() |
177 | | - await context.update_plan(mock_plan) |
178 | | - |
179 | | - mock_container.upsert_item.assert_called_once_with( |
180 | | - body={"id": "plan1", "data": "updated-plan-data"} |
181 | | - ) |
182 | | - |
183 | | - |
184 | | -@pytest.mark.asyncio |
185 | | -async def test_add_session(mock_config, mock_cosmos_client): |
186 | | - """Test adding a session to Cosmos DB.""" |
187 | | - _, mock_container = mock_cosmos_client |
188 | | - mock_session = MagicMock() |
189 | | - mock_session.model_dump.return_value = {"id": "session1", "data": "session-data"} |
190 | | - |
191 | | - context = CosmosBufferedChatCompletionContext( |
192 | | - session_id="test_session", user_id="test_user" |
193 | | - ) |
194 | | - await context.initialize() |
195 | | - await context.add_session(mock_session) |
196 | | - |
197 | | - mock_container.create_item.assert_called_once_with( |
198 | | - body={"id": "session1", "data": "session-data"} |
199 | | - ) |
200 | | - |
201 | | - |
202 | | -@pytest.mark.asyncio |
203 | | -async def test_initialize_event(mock_config, mock_cosmos_client): |
204 | | - """Test the initialization event is set.""" |
205 | | - _, _ = mock_cosmos_client |
206 | | - context = CosmosBufferedChatCompletionContext( |
207 | | - session_id="test_session", user_id="test_user" |
208 | | - ) |
209 | | - assert not context._initialized.is_set() |
210 | | - await context.initialize() |
211 | | - assert context._initialized.is_set() |
212 | | - |
213 | | - |
214 | | -@pytest.mark.asyncio |
215 | | -async def test_get_data_by_invalid_type(mock_config, mock_cosmos_client): |
216 | | - """Test querying data with an invalid type.""" |
217 | | - _, _ = mock_cosmos_client |
218 | | - context = CosmosBufferedChatCompletionContext( |
219 | | - session_id="test_session", user_id="test_user" |
220 | | - ) |
221 | | - |
222 | | - result = await context.get_data_by_type("invalid_type") |
223 | | - |
224 | | - assert result == [] # Expect empty result for invalid type |
225 | | - |
226 | | - |
227 | | -@pytest.mark.asyncio |
228 | | -async def test_get_plan_by_invalid_session(mock_config, mock_cosmos_client): |
229 | | - """Test retrieving a plan with an invalid session ID.""" |
230 | | - _, mock_container = mock_cosmos_client |
231 | | - mock_container.query_items.return_value = async_iterable( |
232 | | - [] |
233 | | - ) # No results for invalid session |
234 | | - |
235 | | - context = CosmosBufferedChatCompletionContext( |
236 | | - session_id="test_session", user_id="test_user" |
237 | | - ) |
238 | | - await context.initialize() |
239 | | - result = await context.get_plan_by_session("invalid_session") |
240 | | - |
241 | | - assert result is None |
242 | | - |
243 | | - |
244 | | -@pytest.mark.asyncio |
245 | | -async def test_delete_item_error_handling(mock_config, mock_cosmos_client): |
246 | | - """Test error handling when deleting an item.""" |
247 | | - _, mock_container = mock_cosmos_client |
248 | | - mock_container.delete_item.side_effect = Exception("Delete error") |
249 | | - |
250 | | - context = CosmosBufferedChatCompletionContext( |
251 | | - session_id="test_session", user_id="test_user" |
252 | | - ) |
253 | | - await context.initialize() |
254 | | - await context.delete_item( |
255 | | - "test-item", "test-partition" |
256 | | - ) # Expect no exception to propagate |
257 | | - |
| 88 | + mock_item.model_dump.return_value = {"id": "test-item", "data": "test-data"} |
258 | 89 |
|
259 | | -@pytest.mark.asyncio |
260 | | -async def test_close_without_initialization(mock_config, mock_cosmos_client): |
261 | | - """Test close method without prior initialization.""" |
262 | | - context = CosmosBufferedChatCompletionContext( |
| 90 | + async with CosmosBufferedChatCompletionContext( |
263 | 91 | session_id="test_session", user_id="test_user" |
264 | | - ) |
265 | | - # Expect no exceptions when closing uninitialized context |
266 | | - await context.close() |
| 92 | + ) as context: |
| 93 | + await context.initialize() |
| 94 | + await context.add_item(mock_item) |
| 95 | + mock_container.create_item.assert_called_once_with( |
| 96 | + body={"id": "test-item", "data": "test-data"} |
| 97 | + ) |
0 commit comments