Skip to content

Commit 85d2445

Browse files
authored
Fix Local behaviour with asyncio Task (#478)
Redid CVar storage as a dict of contextvars
1 parent 8e39bcc commit 85d2445

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

asgiref/local.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,38 @@
22
import contextlib
33
import contextvars
44
import threading
5-
from typing import Any, Dict, Union
5+
from typing import Any, Union
66

77

88
class _CVar:
99
"""Storage utility for Local."""
1010

1111
def __init__(self) -> None:
12-
self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar(
13-
"asgiref.local"
14-
)
12+
self._data: dict[str, contextvars.ContextVar[Any]] = {}
1513

16-
def __getattr__(self, key):
17-
storage_object = self._data.get({})
14+
def __getattr__(self, key: str) -> Any:
1815
try:
19-
return storage_object[key]
16+
var = self._data[key]
2017
except KeyError:
2118
raise AttributeError(f"{self!r} object has no attribute {key!r}")
2219

20+
try:
21+
return var.get()
22+
except LookupError:
23+
raise AttributeError(f"{self!r} object has no attribute {key!r}")
24+
2325
def __setattr__(self, key: str, value: Any) -> None:
2426
if key == "_data":
2527
return super().__setattr__(key, value)
2628

27-
storage_object = self._data.get({})
28-
storage_object[key] = value
29-
self._data.set(storage_object)
29+
var = self._data.get(key)
30+
if var is None:
31+
self._data[key] = var = contextvars.ContextVar(key)
32+
var.set(value)
3033

3134
def __delattr__(self, key: str) -> None:
32-
storage_object = self._data.get({})
33-
if key in storage_object:
34-
del storage_object[key]
35-
self._data.set(storage_object)
35+
if key in self._data:
36+
del self._data[key]
3637
else:
3738
raise AttributeError(f"{self!r} object has no attribute {key!r}")
3839

tests/test_local.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import gc
33
import threading
4+
from threading import Thread
45

56
import pytest
67

@@ -338,3 +339,39 @@ async def async_function():
338339
# inner value was set inside a new async context, meaning that
339340
# we do not see it, as context vars don't propagate up the stack
340341
assert not hasattr(test_local_not_tc, "test_value")
342+
343+
344+
def test_visibility_thread_asgiref() -> None:
345+
"""Check visibility with subthreads."""
346+
test_local = Local()
347+
test_local.value = 0
348+
349+
def _test() -> None:
350+
# Local() is cleared when changing thread
351+
assert not hasattr(test_local, "value")
352+
setattr(test_local, "value", 1)
353+
assert test_local.value == 1
354+
355+
thread = Thread(target=_test)
356+
thread.start()
357+
thread.join()
358+
359+
assert test_local.value == 0
360+
361+
362+
@pytest.mark.asyncio
363+
async def test_visibility_task() -> None:
364+
"""Check visibility with asyncio tasks."""
365+
test_local = Local()
366+
test_local.value = 0
367+
368+
async def _test() -> None:
369+
# Local is inherited when changing task
370+
assert test_local.value == 0
371+
test_local.value = 1
372+
assert test_local.value == 1
373+
374+
await asyncio.create_task(_test())
375+
376+
# Changes should not leak to the caller
377+
assert test_local.value == 0

0 commit comments

Comments
 (0)