Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,29 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
func = context.run
task_context: List[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
try:
# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
except Exception:
_restore_context(context)
self.deadlock_context.set(False)
raise

ret: _R
try:
ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
# catch CancelledError only in await
cancel_parent = True
try:
task = task_context[0]
Expand Down
17 changes: 16 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any
from unittest import TestCase
from unittest import TestCase, mock

import pytest

Expand Down Expand Up @@ -197,6 +197,21 @@ def test_method(self):
assert method.__self__ == instance


@pytest.mark.asyncio
@pytest.mark.parametrize("exc_cls", (RuntimeError, ValueError, Exception))
async def test_sync_to_async_broken_executor(exc_cls):
"""
Tests sync_to_async catch error in executor and avoid deadlock
"""
with mock.patch.object(ThreadPoolExecutor, "submit") as mock_run:
mock_run.side_effect = exc_cls("Test Error")
async_function = sync_to_async(lambda: None, thread_sensitive=True)
with pytest.raises(exc_cls, match="Test Error"):
await async_function()
with pytest.raises(exc_cls, match="Test Error"):
await async_function()


@pytest.mark.asyncio
async def test_async_to_sync_to_async():
"""
Expand Down