Skip to content

Commit 03f2be1

Browse files
fix: fix race conditions
1 parent 2c6481f commit 03f2be1

File tree

1 file changed

+96
-47
lines changed

1 file changed

+96
-47
lines changed

tests/test_litellm/litellm_core_utils/test_logging_worker.py

Lines changed: 96 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""
22
Tests for the LoggingWorker class to ensure graceful shutdown handling.
33
"""
4+
45
import asyncio
56
import contextvars
6-
import pytest
77
from unittest.mock import AsyncMock, patch
88

9+
import pytest
10+
911
from litellm.litellm_core_utils.logging_worker import LoggingWorker
1012

1113

@@ -21,7 +23,9 @@ def logging_worker(self):
2123
async def test_graceful_shutdown_with_clear_queue(self, logging_worker):
2224
"""Test that cancellation triggers clear_queue to prevent 'never awaited' warnings."""
2325
# Mock the clear_queue method to verify it's called during cancellation
24-
with patch.object(logging_worker, "clear_queue", new_callable=AsyncMock) as mock_clear_queue:
26+
with patch.object(
27+
logging_worker, "clear_queue", new_callable=AsyncMock
28+
) as mock_clear_queue:
2529
# Start the worker
2630
logging_worker.start()
2731

@@ -64,7 +68,9 @@ async def test_clear_queue_processes_remaining_items(self, logging_worker):
6468
async def test_worker_handles_cancellation_gracefully(self, logging_worker):
6569
"""Test that the worker handles cancellation without throwing exceptions."""
6670
# Mock verbose_logger to capture debug messages
67-
with patch("litellm.litellm_core_utils.logging_worker.verbose_logger") as mock_logger:
71+
with patch(
72+
"litellm.litellm_core_utils.logging_worker.verbose_logger"
73+
) as mock_logger:
6874
# Start the worker
6975
logging_worker.start()
7076

@@ -131,90 +137,133 @@ async def test_queue_full_handling(self, logging_worker):
131137
small_worker._ensure_queue()
132138

133139
# Mock verbose_logger to capture exception messages
134-
with patch("litellm.litellm_core_utils.logging_worker.verbose_logger") as mock_logger:
140+
with patch(
141+
"litellm.litellm_core_utils.logging_worker.verbose_logger"
142+
) as mock_logger:
135143
# Fill the queue beyond capacity
136144
mock_coro = AsyncMock()
137145
for _ in range(5): # More than max_queue_size of 2
138146
small_worker.enqueue(mock_coro())
139147

140148
# Should have logged queue full exceptions
141-
exception_calls = [call for call in mock_logger.exception.call_args_list if "queue is full" in str(call)]
149+
exception_calls = [
150+
call
151+
for call in mock_logger.exception.call_args_list
152+
if "queue is full" in str(call)
153+
]
142154
assert len(exception_calls) > 0
143155

144156
@pytest.mark.asyncio
145157
async def test_context_propagation(self, logging_worker):
146158
"""Test that enqueued tasks execute in their original context."""
147159
# Create a context variable for testing
148-
test_context_var: contextvars.ContextVar[str] = contextvars.ContextVar('test_context_var')
149-
150-
# Track results from multiple tasks
160+
test_context_var: contextvars.ContextVar[str] = contextvars.ContextVar(
161+
"test_context_var"
162+
)
163+
164+
# Track results from multiple tasks using asyncio.Event for synchronization
151165
task_results = []
152-
166+
completion_events = {}
167+
153168
async def test_task(task_id: str):
154169
"""A test coroutine that checks if it can access the context variable."""
155-
# Sleep a bit to simulate real work and ensure context persists
156-
await asyncio.sleep(0.1)
157-
158170
try:
159171
# Try to get the context variable value
160172
value = test_context_var.get()
161-
task_results.append({
162-
'task_id': task_id,
163-
'context_value': value,
164-
'context_accessible': True
165-
})
173+
task_results.append(
174+
{
175+
"task_id": task_id,
176+
"context_value": value,
177+
"context_accessible": True,
178+
}
179+
)
166180
except LookupError:
167181
# Context variable not found
168-
task_results.append({
169-
'task_id': task_id,
170-
'context_accessible': False,
171-
'context_value': None
172-
})
173-
182+
task_results.append(
183+
{
184+
"task_id": task_id,
185+
"context_accessible": False,
186+
"context_value": None,
187+
}
188+
)
189+
finally:
190+
# Signal that this task is complete
191+
completion_events[task_id].set()
192+
193+
# Create completion events for each task
194+
completion_events["task_1"] = asyncio.Event()
195+
completion_events["task_2"] = asyncio.Event()
196+
completion_events["task_3"] = asyncio.Event()
197+
174198
# Start the logging worker
175199
logging_worker.start()
176-
200+
201+
# Give the worker a moment to start
202+
await asyncio.sleep(0.1)
203+
177204
# Create two separate contexts and enqueue tasks from each
178-
205+
179206
# Context 1: Set context var to "context_1"
180207
ctx1 = contextvars.copy_context()
181208
ctx1.run(test_context_var.set, "context_1")
182209
ctx1.run(logging_worker.enqueue, test_task("task_1"))
183-
184-
# Context 2: Set context var to "context_2"
210+
211+
# Context 2: Set context var to "context_2"
185212
ctx2 = contextvars.copy_context()
186213
ctx2.run(test_context_var.set, "context_2")
187214
ctx2.run(logging_worker.enqueue, test_task("task_2"))
188-
215+
189216
# Context 3: No context variable set (should get LookupError)
190217
ctx3 = contextvars.copy_context()
191218
ctx3.run(logging_worker.enqueue, test_task("task_3"))
192-
193-
# Wait for all tasks to be processed
194-
await asyncio.sleep(0.5)
195-
219+
220+
# Wait for all tasks to complete with a reasonable timeout
221+
try:
222+
await asyncio.wait_for(
223+
asyncio.gather(
224+
completion_events["task_1"].wait(),
225+
completion_events["task_2"].wait(),
226+
completion_events["task_3"].wait(),
227+
),
228+
timeout=5.0,
229+
)
230+
except asyncio.TimeoutError:
231+
pytest.fail("Tasks did not complete within timeout")
232+
196233
# Stop the worker
197234
await logging_worker.stop()
198-
235+
199236
# Sort results by task_id for consistent testing
200-
task_results.sort(key=lambda x: x['task_id'])
201-
237+
task_results.sort(key=lambda x: x["task_id"])
238+
202239
# Verify that each task saw its own context
203-
assert len(task_results) == 3, f"Expected 3 results, got {len(task_results)}"
204-
240+
assert (
241+
len(task_results) == 3
242+
), f"Expected 3 results, got {len(task_results)}: {task_results}"
243+
205244
# Task 1 should see "context_1"
206-
task1_result = next((r for r in task_results if r['task_id'] == 'task_1'), None)
245+
task1_result = next((r for r in task_results if r["task_id"] == "task_1"), None)
207246
assert task1_result is not None, "Task 1 result not found"
208-
assert task1_result['context_accessible'] is True, "Task 1 should have access to context variable"
209-
assert task1_result['context_value'] == "context_1", f"Task 1 should see 'context_1', got: {task1_result['context_value']}"
210-
247+
assert (
248+
task1_result["context_accessible"] is True
249+
), "Task 1 should have access to context variable"
250+
assert (
251+
task1_result["context_value"] == "context_1"
252+
), f"Task 1 should see 'context_1', got: {task1_result['context_value']}"
253+
211254
# Task 2 should see "context_2"
212-
task2_result = next((r for r in task_results if r['task_id'] == 'task_2'), None)
255+
task2_result = next((r for r in task_results if r["task_id"] == "task_2"), None)
213256
assert task2_result is not None, "Task 2 result not found"
214-
assert task2_result['context_accessible'] is True, "Task 2 should have access to context variable"
215-
assert task2_result['context_value'] == "context_2", f"Task 2 should see 'context_2', got: {task2_result['context_value']}"
216-
257+
assert (
258+
task2_result["context_accessible"] is True
259+
), "Task 2 should have access to context variable"
260+
assert (
261+
task2_result["context_value"] == "context_2"
262+
), f"Task 2 should see 'context_2', got: {task2_result['context_value']}"
263+
217264
# Task 3 should not have access to the context variable
218-
task3_result = next((r for r in task_results if r['task_id'] == 'task_3'), None)
265+
task3_result = next((r for r in task_results if r["task_id"] == "task_3"), None)
219266
assert task3_result is not None, "Task 3 result not found"
220-
assert task3_result['context_accessible'] is False, "Task 3 should not have access to context variable"
267+
assert (
268+
task3_result["context_accessible"] is False
269+
), "Task 3 should not have access to context variable"

0 commit comments

Comments
 (0)