1
- import asyncio
2
- import contextlib
3
1
import contextvars
4
2
import threading
5
3
from typing import Any , Union
@@ -9,36 +7,40 @@ class _CVar:
9
7
"""Storage utility for Local."""
10
8
11
9
def __init__ (self ) -> None :
10
+ self ._thread_lock = threading .RLock ()
12
11
self ._data : dict [str , contextvars .ContextVar [Any ]] = {}
13
12
14
13
def __getattr__ (self , key : str ) -> Any :
15
- try :
16
- var = self ._data [key ]
17
- except KeyError :
18
- raise AttributeError (f"{ self !r} object has no attribute { key !r} " )
14
+ with self ._thread_lock :
15
+ try :
16
+ var = self ._data [key ]
17
+ except KeyError :
18
+ raise AttributeError (f"{ self !r} object has no attribute { key !r} " )
19
19
20
20
try :
21
21
return var .get ()
22
22
except LookupError :
23
23
raise AttributeError (f"{ self !r} object has no attribute { key !r} " )
24
24
25
25
def __setattr__ (self , key : str , value : Any ) -> None :
26
- if key == "_data" :
26
+ if key in ( "_data" , "_thread_lock" ) :
27
27
return super ().__setattr__ (key , value )
28
28
29
- var = self ._data .get (key )
30
- if var is None :
31
- self ._data [key ] = var = contextvars .ContextVar (key )
29
+ with self ._thread_lock :
30
+ var = self ._data .get (key )
31
+ if var is None :
32
+ self ._data [key ] = var = contextvars .ContextVar (key )
32
33
var .set (value )
33
34
34
35
def __delattr__ (self , key : str ) -> None :
35
- if key in self ._data :
36
- del self ._data [key ]
37
- else :
38
- raise AttributeError (f"{ self !r} object has no attribute { key !r} " )
36
+ with self ._thread_lock :
37
+ if key in self ._data :
38
+ del self ._data [key ]
39
+ else :
40
+ raise AttributeError (f"{ self !r} object has no attribute { key !r} " )
39
41
40
42
41
- class Local :
43
+ def Local ( thread_critical : bool = False ) -> Union [ threading . local , _CVar ] :
42
44
"""Local storage for async tasks.
43
45
44
46
This is a namespace object (similar to `threading.local`) where data is
@@ -65,65 +67,7 @@ class Local:
65
67
66
68
Unlike plain `contextvars` objects, this utility is threadsafe.
67
69
"""
68
-
69
- def __init__ (self , thread_critical : bool = False ) -> None :
70
- self ._thread_critical = thread_critical
71
- self ._thread_lock = threading .RLock ()
72
-
73
- self ._storage : "Union[threading.local, _CVar]"
74
-
75
- if thread_critical :
76
- # Thread-local storage
77
- self ._storage = threading .local ()
78
- else :
79
- # Contextvar storage
80
- self ._storage = _CVar ()
81
-
82
- @contextlib .contextmanager
83
- def _lock_storage (self ):
84
- # Thread safe access to storage
85
- if self ._thread_critical :
86
- try :
87
- # this is a test for are we in a async or sync
88
- # thread - will raise RuntimeError if there is
89
- # no current loop
90
- asyncio .get_running_loop ()
91
- except RuntimeError :
92
- # We are in a sync thread, the storage is
93
- # just the plain thread local (i.e, "global within
94
- # this thread" - it doesn't matter where you are
95
- # in a call stack you see the same storage)
96
- yield self ._storage
97
- else :
98
- # We are in an async thread - storage is still
99
- # local to this thread, but additionally should
100
- # behave like a context var (is only visible with
101
- # the same async call stack)
102
-
103
- # Ensure context exists in the current thread
104
- if not hasattr (self ._storage , "cvar" ):
105
- self ._storage .cvar = _CVar ()
106
-
107
- # self._storage is a thread local, so the members
108
- # can't be accessed in another thread (we don't
109
- # need any locks)
110
- yield self ._storage .cvar
111
- else :
112
- # Lock for thread_critical=False as other threads
113
- # can access the exact same storage object
114
- with self ._thread_lock :
115
- yield self ._storage
116
-
117
- def __getattr__ (self , key ):
118
- with self ._lock_storage () as storage :
119
- return getattr (storage , key )
120
-
121
- def __setattr__ (self , key , value ):
122
- if key in ("_local" , "_storage" , "_thread_critical" , "_thread_lock" ):
123
- return super ().__setattr__ (key , value )
124
- with self ._lock_storage () as storage :
125
- setattr (storage , key , value )
126
-
127
- def __delattr__ (self , key ):
128
- with self ._lock_storage () as storage :
129
- delattr (storage , key )
70
+ if thread_critical :
71
+ return threading .local ()
72
+ else :
73
+ return _CVar ()
0 commit comments