1
1
"""
2
2
Tests for the LoggingWorker class to ensure graceful shutdown handling.
3
3
"""
4
+
4
5
import asyncio
5
6
import contextvars
6
- import pytest
7
7
from unittest .mock import AsyncMock , patch
8
8
9
+ import pytest
10
+
9
11
from litellm .litellm_core_utils .logging_worker import LoggingWorker
10
12
11
13
@@ -21,7 +23,9 @@ def logging_worker(self):
21
23
async def test_graceful_shutdown_with_clear_queue (self , logging_worker ):
22
24
"""Test that cancellation triggers clear_queue to prevent 'never awaited' warnings."""
23
25
# Mock the clear_queue method to verify it's called during cancellation
24
- with patch .object (logging_worker , "clear_queue" , new_callable = AsyncMock ) as mock_clear_queue :
26
+ with patch .object (
27
+ logging_worker , "clear_queue" , new_callable = AsyncMock
28
+ ) as mock_clear_queue :
25
29
# Start the worker
26
30
logging_worker .start ()
27
31
@@ -64,7 +68,9 @@ async def test_clear_queue_processes_remaining_items(self, logging_worker):
64
68
async def test_worker_handles_cancellation_gracefully (self , logging_worker ):
65
69
"""Test that the worker handles cancellation without throwing exceptions."""
66
70
# Mock verbose_logger to capture debug messages
67
- with patch ("litellm.litellm_core_utils.logging_worker.verbose_logger" ) as mock_logger :
71
+ with patch (
72
+ "litellm.litellm_core_utils.logging_worker.verbose_logger"
73
+ ) as mock_logger :
68
74
# Start the worker
69
75
logging_worker .start ()
70
76
@@ -131,90 +137,133 @@ async def test_queue_full_handling(self, logging_worker):
131
137
small_worker ._ensure_queue ()
132
138
133
139
# Mock verbose_logger to capture exception messages
134
- with patch ("litellm.litellm_core_utils.logging_worker.verbose_logger" ) as mock_logger :
140
+ with patch (
141
+ "litellm.litellm_core_utils.logging_worker.verbose_logger"
142
+ ) as mock_logger :
135
143
# Fill the queue beyond capacity
136
144
mock_coro = AsyncMock ()
137
145
for _ in range (5 ): # More than max_queue_size of 2
138
146
small_worker .enqueue (mock_coro ())
139
147
140
148
# Should have logged queue full exceptions
141
- exception_calls = [call for call in mock_logger .exception .call_args_list if "queue is full" in str (call )]
149
+ exception_calls = [
150
+ call
151
+ for call in mock_logger .exception .call_args_list
152
+ if "queue is full" in str (call )
153
+ ]
142
154
assert len (exception_calls ) > 0
143
155
144
156
@pytest .mark .asyncio
145
157
async def test_context_propagation (self , logging_worker ):
146
158
"""Test that enqueued tasks execute in their original context."""
147
159
# Create a context variable for testing
148
- test_context_var : contextvars .ContextVar [str ] = contextvars .ContextVar ('test_context_var' )
149
-
150
- # Track results from multiple tasks
160
+ test_context_var : contextvars .ContextVar [str ] = contextvars .ContextVar (
161
+ "test_context_var"
162
+ )
163
+
164
+ # Track results from multiple tasks using asyncio.Event for synchronization
151
165
task_results = []
152
-
166
+ completion_events = {}
167
+
153
168
async def test_task (task_id : str ):
154
169
"""A test coroutine that checks if it can access the context variable."""
155
- # Sleep a bit to simulate real work and ensure context persists
156
- await asyncio .sleep (0.1 )
157
-
158
170
try :
159
171
# Try to get the context variable value
160
172
value = test_context_var .get ()
161
- task_results .append ({
162
- 'task_id' : task_id ,
163
- 'context_value' : value ,
164
- 'context_accessible' : True
165
- })
173
+ task_results .append (
174
+ {
175
+ "task_id" : task_id ,
176
+ "context_value" : value ,
177
+ "context_accessible" : True ,
178
+ }
179
+ )
166
180
except LookupError :
167
181
# Context variable not found
168
- task_results .append ({
169
- 'task_id' : task_id ,
170
- 'context_accessible' : False ,
171
- 'context_value' : None
172
- })
173
-
182
+ task_results .append (
183
+ {
184
+ "task_id" : task_id ,
185
+ "context_accessible" : False ,
186
+ "context_value" : None ,
187
+ }
188
+ )
189
+ finally :
190
+ # Signal that this task is complete
191
+ completion_events [task_id ].set ()
192
+
193
+ # Create completion events for each task
194
+ completion_events ["task_1" ] = asyncio .Event ()
195
+ completion_events ["task_2" ] = asyncio .Event ()
196
+ completion_events ["task_3" ] = asyncio .Event ()
197
+
174
198
# Start the logging worker
175
199
logging_worker .start ()
176
-
200
+
201
+ # Give the worker a moment to start
202
+ await asyncio .sleep (0.1 )
203
+
177
204
# Create two separate contexts and enqueue tasks from each
178
-
205
+
179
206
# Context 1: Set context var to "context_1"
180
207
ctx1 = contextvars .copy_context ()
181
208
ctx1 .run (test_context_var .set , "context_1" )
182
209
ctx1 .run (logging_worker .enqueue , test_task ("task_1" ))
183
-
184
- # Context 2: Set context var to "context_2"
210
+
211
+ # Context 2: Set context var to "context_2"
185
212
ctx2 = contextvars .copy_context ()
186
213
ctx2 .run (test_context_var .set , "context_2" )
187
214
ctx2 .run (logging_worker .enqueue , test_task ("task_2" ))
188
-
215
+
189
216
# Context 3: No context variable set (should get LookupError)
190
217
ctx3 = contextvars .copy_context ()
191
218
ctx3 .run (logging_worker .enqueue , test_task ("task_3" ))
192
-
193
- # Wait for all tasks to be processed
194
- await asyncio .sleep (0.5 )
195
-
219
+
220
+ # Wait for all tasks to complete with a reasonable timeout
221
+ try :
222
+ await asyncio .wait_for (
223
+ asyncio .gather (
224
+ completion_events ["task_1" ].wait (),
225
+ completion_events ["task_2" ].wait (),
226
+ completion_events ["task_3" ].wait (),
227
+ ),
228
+ timeout = 5.0 ,
229
+ )
230
+ except asyncio .TimeoutError :
231
+ pytest .fail ("Tasks did not complete within timeout" )
232
+
196
233
# Stop the worker
197
234
await logging_worker .stop ()
198
-
235
+
199
236
# Sort results by task_id for consistent testing
200
- task_results .sort (key = lambda x : x [' task_id' ])
201
-
237
+ task_results .sort (key = lambda x : x [" task_id" ])
238
+
202
239
# Verify that each task saw its own context
203
- assert len (task_results ) == 3 , f"Expected 3 results, got { len (task_results )} "
204
-
240
+ assert (
241
+ len (task_results ) == 3
242
+ ), f"Expected 3 results, got { len (task_results )} : { task_results } "
243
+
205
244
# Task 1 should see "context_1"
206
- task1_result = next ((r for r in task_results if r [' task_id' ] == ' task_1' ), None )
245
+ task1_result = next ((r for r in task_results if r [" task_id" ] == " task_1" ), None )
207
246
assert task1_result is not None , "Task 1 result not found"
208
- assert task1_result ['context_accessible' ] is True , "Task 1 should have access to context variable"
209
- assert task1_result ['context_value' ] == "context_1" , f"Task 1 should see 'context_1', got: { task1_result ['context_value' ]} "
210
-
247
+ assert (
248
+ task1_result ["context_accessible" ] is True
249
+ ), "Task 1 should have access to context variable"
250
+ assert (
251
+ task1_result ["context_value" ] == "context_1"
252
+ ), f"Task 1 should see 'context_1', got: { task1_result ['context_value' ]} "
253
+
211
254
# Task 2 should see "context_2"
212
- task2_result = next ((r for r in task_results if r [' task_id' ] == ' task_2' ), None )
255
+ task2_result = next ((r for r in task_results if r [" task_id" ] == " task_2" ), None )
213
256
assert task2_result is not None , "Task 2 result not found"
214
- assert task2_result ['context_accessible' ] is True , "Task 2 should have access to context variable"
215
- assert task2_result ['context_value' ] == "context_2" , f"Task 2 should see 'context_2', got: { task2_result ['context_value' ]} "
216
-
257
+ assert (
258
+ task2_result ["context_accessible" ] is True
259
+ ), "Task 2 should have access to context variable"
260
+ assert (
261
+ task2_result ["context_value" ] == "context_2"
262
+ ), f"Task 2 should see 'context_2', got: { task2_result ['context_value' ]} "
263
+
217
264
# Task 3 should not have access to the context variable
218
- task3_result = next ((r for r in task_results if r [' task_id' ] == ' task_3' ), None )
265
+ task3_result = next ((r for r in task_results if r [" task_id" ] == " task_3" ), None )
219
266
assert task3_result is not None , "Task 3 result not found"
220
- assert task3_result ['context_accessible' ] is False , "Task 3 should not have access to context variable"
267
+ assert (
268
+ task3_result ["context_accessible" ] is False
269
+ ), "Task 3 should not have access to context variable"
0 commit comments