diff --git a/asgiref/sync.py b/asgiref/sync.py index 53f1900b..9a43c4a4 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -180,12 +180,7 @@ class SyncToAsync: rather than just blocking. """ - # If they've set ASGI_THREADS, update the default asyncio executor for now - if "ASGI_THREADS" in os.environ: - loop = asyncio.get_event_loop() - loop.set_default_executor( - ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"])) - ) + executor = None # Maps launched threads to the coroutines that spawned them launch_map = {} @@ -205,6 +200,14 @@ def __init__(self, func, thread_sensitive=False): except AttributeError: pass + if self.__class__.executor is None: + executor_kwargs = {} + if "ASGI_THREADS" in os.environ: + executor_kwargs["max_workers"] = int(os.environ["ASGI_THREADS"]) + if sys.version_info >= (3, 6): + executor_kwargs["thread_name_prefix"] = "sync_to_async" + self.__class__.executor = ThreadPoolExecutor(**executor_kwargs) + async def __call__(self, *args, **kwargs): loop = asyncio.get_event_loop() @@ -217,7 +220,7 @@ async def __call__(self, *args, **kwargs): # Otherwise, we run it in a fixed single thread executor = self.single_thread_executor else: - executor = None # Use default + executor = self.__class__.executor # Use default if contextvars is not None: context = contextvars.copy_context() diff --git a/tests/test_sync.py b/tests/test_sync.py index 5a6e3547..e1f665ba 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio -async def test_sync_to_async(): +async def test_sync_to_async(monkeypatch): """ Tests we can call sync functions from an async thread (even if the number of thread workers is less than the number of calls) @@ -29,18 +29,44 @@ def sync_function(): end = time.monotonic() assert result == 42 assert end - start >= 1 + # Set workers to 1, call it twice and make sure that works right - loop = asyncio.get_event_loop() - old_executor = loop._default_executor - loop.set_default_executor(ThreadPoolExecutor(max_workers=1)) - try: - start = time.monotonic() - await asyncio.wait([async_function(), async_function()]) - end = time.monotonic() - # It should take at least 2 seconds as there's only one worker. - assert end - start >= 2 - finally: - loop.set_default_executor(old_executor) + monkeypatch.setattr(sync_to_async, "executor", ThreadPoolExecutor(max_workers=1)) + async_function = sync_to_async(sync_function) + start = time.monotonic() + await asyncio.wait([async_function(), async_function()]) + end = time.monotonic() + # It should take at least 2 seconds as there's only one worker. + assert end - start >= 2 + + +@pytest.mark.asyncio +async def test_sync_to_async_ASGI_THREADS(monkeypatch): + def sync_function(): + time.sleep(0.5) + return 42 + + # Set workers to 1 via env, call it twice and make sure that works right + monkeypatch.setenv("ASGI_THREADS", 1) + monkeypatch.setattr(sync_to_async, "executor", None) + + async_function = sync_to_async(sync_function) + assert async_function.executor._max_workers == 1 + start = time.monotonic() + await asyncio.wait([async_function(), async_function()]) + end = time.monotonic() + # It should take at least 1 second as there's only one worker. + assert end - start >= 1 + + # Uses existing executor instance on class. + monkeypatch.setenv("ASGI_THREADS", 99) + orig_executor = async_function.executor + async_function = sync_to_async(sync_function) + assert async_function.executor is orig_executor + + monkeypatch.setattr(sync_to_async, "executor", None) + async_function = sync_to_async(sync_function) + assert async_function.executor is not orig_executor @pytest.mark.asyncio