Skip to content

Commit 7b3ce47

Browse files
committed
Simplify code leveraging the fact that ContextVar is automatically thread-local
1 parent 16d4ae2 commit 7b3ce47

File tree

1 file changed

+21
-77
lines changed

1 file changed

+21
-77
lines changed

asgiref/local.py

Lines changed: 21 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import asyncio
2-
import contextlib
31
import contextvars
42
import threading
53
from typing import Any, Union
@@ -9,36 +7,40 @@ class _CVar:
97
"""Storage utility for Local."""
108

119
def __init__(self) -> None:
10+
self._thread_lock = threading.RLock()
1211
self._data: dict[str, contextvars.ContextVar[Any]] = {}
1312

1413
def __getattr__(self, key: str) -> Any:
15-
try:
16-
var = self._data[key]
17-
except KeyError:
18-
raise AttributeError(f"{self!r} object has no attribute {key!r}")
14+
with self._thread_lock:
15+
try:
16+
var = self._data[key]
17+
except KeyError:
18+
raise AttributeError(f"{self!r} object has no attribute {key!r}")
1919

2020
try:
2121
return var.get()
2222
except LookupError:
2323
raise AttributeError(f"{self!r} object has no attribute {key!r}")
2424

2525
def __setattr__(self, key: str, value: Any) -> None:
26-
if key == "_data":
26+
if key in ("_data", "_thread_lock"):
2727
return super().__setattr__(key, value)
2828

29-
var = self._data.get(key)
30-
if var is None:
31-
self._data[key] = var = contextvars.ContextVar(key)
29+
with self._thread_lock:
30+
var = self._data.get(key)
31+
if var is None:
32+
self._data[key] = var = contextvars.ContextVar(key)
3233
var.set(value)
3334

3435
def __delattr__(self, key: str) -> None:
35-
if key in self._data:
36-
del self._data[key]
37-
else:
38-
raise AttributeError(f"{self!r} object has no attribute {key!r}")
36+
with self._thread_lock:
37+
if key in self._data:
38+
del self._data[key]
39+
else:
40+
raise AttributeError(f"{self!r} object has no attribute {key!r}")
3941

4042

41-
class Local:
43+
def Local(thread_critical: bool = False) -> Union[threading.local, _CVar]:
4244
"""Local storage for async tasks.
4345
4446
This is a namespace object (similar to `threading.local`) where data is
@@ -65,65 +67,7 @@ class Local:
6567
6668
Unlike plain `contextvars` objects, this utility is threadsafe.
6769
"""
68-
69-
def __init__(self, thread_critical: bool = False) -> None:
70-
self._thread_critical = thread_critical
71-
self._thread_lock = threading.RLock()
72-
73-
self._storage: "Union[threading.local, _CVar]"
74-
75-
if thread_critical:
76-
# Thread-local storage
77-
self._storage = threading.local()
78-
else:
79-
# Contextvar storage
80-
self._storage = _CVar()
81-
82-
@contextlib.contextmanager
83-
def _lock_storage(self):
84-
# Thread safe access to storage
85-
if self._thread_critical:
86-
try:
87-
# this is a test for are we in a async or sync
88-
# thread - will raise RuntimeError if there is
89-
# no current loop
90-
asyncio.get_running_loop()
91-
except RuntimeError:
92-
# We are in a sync thread, the storage is
93-
# just the plain thread local (i.e, "global within
94-
# this thread" - it doesn't matter where you are
95-
# in a call stack you see the same storage)
96-
yield self._storage
97-
else:
98-
# We are in an async thread - storage is still
99-
# local to this thread, but additionally should
100-
# behave like a context var (is only visible with
101-
# the same async call stack)
102-
103-
# Ensure context exists in the current thread
104-
if not hasattr(self._storage, "cvar"):
105-
self._storage.cvar = _CVar()
106-
107-
# self._storage is a thread local, so the members
108-
# can't be accessed in another thread (we don't
109-
# need any locks)
110-
yield self._storage.cvar
111-
else:
112-
# Lock for thread_critical=False as other threads
113-
# can access the exact same storage object
114-
with self._thread_lock:
115-
yield self._storage
116-
117-
def __getattr__(self, key):
118-
with self._lock_storage() as storage:
119-
return getattr(storage, key)
120-
121-
def __setattr__(self, key, value):
122-
if key in ("_local", "_storage", "_thread_critical", "_thread_lock"):
123-
return super().__setattr__(key, value)
124-
with self._lock_storage() as storage:
125-
setattr(storage, key, value)
126-
127-
def __delattr__(self, key):
128-
with self._lock_storage() as storage:
129-
delattr(storage, key)
70+
if thread_critical:
71+
return threading.local()
72+
else:
73+
return _CVar()

0 commit comments

Comments
 (0)