diff --git a/asgiref/sync.py b/asgiref/sync.py index 377075d1..7f9e6f50 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -452,22 +452,29 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: func = context.run task_context: List[asyncio.Task[Any]] = [] - # Run the code in the right thread - exec_coro = loop.run_in_executor( - executor, - functools.partial( - self.thread_handler, - loop, - sys.exc_info(), - task_context, - func, - child, - ), - ) + try: + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + self.thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ), + ) + except Exception: + _restore_context(context) + self.deadlock_context.set(False) + raise + ret: _R try: ret = await asyncio.shield(exec_coro) except asyncio.CancelledError: + # catch CancelledError only in await cancel_parent = True try: task = task_context[0] diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c67308c..c8253c2b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import wraps from typing import Any -from unittest import TestCase +from unittest import TestCase, mock import pytest @@ -197,6 +197,21 @@ def test_method(self): assert method.__self__ == instance +@pytest.mark.asyncio +@pytest.mark.parametrize("exc_cls", (RuntimeError, ValueError, Exception)) +async def test_sync_to_async_broken_executor(exc_cls): + """ + Tests sync_to_async catch error in executor and avoid deadlock + """ + with mock.patch.object(ThreadPoolExecutor, "submit") as mock_run: + mock_run.side_effect = exc_cls("Test Error") + async_function = sync_to_async(lambda: None, thread_sensitive=True) + with pytest.raises(exc_cls, match="Test Error"): + await async_function() + with pytest.raises(exc_cls, match="Test Error"): + await async_function() + + @pytest.mark.asyncio async def test_async_to_sync_to_async(): """