Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,7 @@ class SyncToAsync:
rather than just blocking.
"""

# If they've set ASGI_THREADS, update the default asyncio executor for now
if "ASGI_THREADS" in os.environ:
loop = asyncio.get_event_loop()
loop.set_default_executor(
ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"]))
)
executor = None

# Maps launched threads to the coroutines that spawned them
launch_map = {}
Expand All @@ -205,6 +200,14 @@ def __init__(self, func, thread_sensitive=False):
except AttributeError:
pass

if self.__class__.executor is None:
executor_kwargs = {}
if "ASGI_THREADS" in os.environ:
executor_kwargs["max_workers"] = int(os.environ["ASGI_THREADS"])
if sys.version_info >= (3, 6):
executor_kwargs["thread_name_prefix"] = "sync_to_async"
self.__class__.executor = ThreadPoolExecutor(**executor_kwargs)

async def __call__(self, *args, **kwargs):
loop = asyncio.get_event_loop()

Expand All @@ -217,7 +220,7 @@ async def __call__(self, *args, **kwargs):
# Otherwise, we run it in a fixed single thread
executor = self.single_thread_executor
else:
executor = None # Use default
executor = self.__class__.executor # Use default

if contextvars is not None:
context = contextvars.copy_context()
Expand Down
50 changes: 38 additions & 12 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.mark.asyncio
async def test_sync_to_async():
async def test_sync_to_async(monkeypatch):
"""
Tests we can call sync functions from an async thread
(even if the number of thread workers is less than the number of calls)
Expand All @@ -29,18 +29,44 @@ def sync_function():
end = time.monotonic()
assert result == 42
assert end - start >= 1

# Set workers to 1, call it twice and make sure that works right
loop = asyncio.get_event_loop()
old_executor = loop._default_executor
loop.set_default_executor(ThreadPoolExecutor(max_workers=1))
try:
start = time.monotonic()
await asyncio.wait([async_function(), async_function()])
end = time.monotonic()
# It should take at least 2 seconds as there's only one worker.
assert end - start >= 2
finally:
loop.set_default_executor(old_executor)
monkeypatch.setattr(sync_to_async, "executor", ThreadPoolExecutor(max_workers=1))
async_function = sync_to_async(sync_function)
start = time.monotonic()
await asyncio.wait([async_function(), async_function()])
end = time.monotonic()
# It should take at least 2 seconds as there's only one worker.
assert end - start >= 2


@pytest.mark.asyncio
async def test_sync_to_async_ASGI_THREADS(monkeypatch):
def sync_function():
time.sleep(0.5)
return 42

# Set workers to 1 via env, call it twice and make sure that works right
monkeypatch.setenv("ASGI_THREADS", 1)
monkeypatch.setattr(sync_to_async, "executor", None)

async_function = sync_to_async(sync_function)
assert async_function.executor._max_workers == 1
start = time.monotonic()
await asyncio.wait([async_function(), async_function()])
end = time.monotonic()
# It should take at least 1 second as there's only one worker.
assert end - start >= 1

# Uses existing executor instance on class.
monkeypatch.setenv("ASGI_THREADS", 99)
orig_executor = async_function.executor
async_function = sync_to_async(sync_function)
assert async_function.executor is orig_executor
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this though: it might actually make sense to use a separate executor for every decorator?
This could also be achieved through a new kwarg.


monkeypatch.setattr(sync_to_async, "executor", None)
async_function = sync_to_async(sync_function)
assert async_function.executor is not orig_executor


@pytest.mark.asyncio
Expand Down