11from __future__ import annotations
22
3- from abc import abstractmethod
3+ from abc import ABC , abstractmethod
44from collections import defaultdict
5- from collections .abc import Iterator , MutableMapping
5+ from collections .abc import AsyncIterator , Iterator , MutableMapping
66from contextlib import (
77 AsyncContextDecorator ,
88 AsyncExitStack ,
99 ContextDecorator ,
1010 ExitStack ,
11+ asynccontextmanager ,
1112 contextmanager ,
1213)
1314from contextvars import ContextVar
3132)
3233
3334
34- @dataclass (repr = False , frozen = True , slots = True )
35- class _ActiveScope :
35+ @dataclass (repr = False , slots = True )
36+ class _ScopeState :
3637 # Shouldn't be instantiated outside `__SCOPES`.
3738
38- context_var : ContextVar [Scope ] = field (
39+ __context_var : ContextVar [Scope ] = field (
3940 default_factory = lambda : ContextVar (f"scope@{ new_short_key ()} " ),
4041 init = False ,
4142 )
42- references : set [Scope ] = field (
43+ __references : set [Scope ] = field (
4344 default_factory = set ,
4445 init = False ,
4546 )
47+ __shared_value : Scope | None = field (
48+ default = None ,
49+ init = False ,
50+ )
4651
47- def to_tuple (self ) -> tuple [ContextVar [Scope ], set [Scope ]]:
48- return self .context_var , self .references
52+ @contextmanager
53+ def bind_contextual_scope (self , scope : Scope ) -> Iterator [None ]:
54+ self .__references .add (scope )
55+ token = self .__context_var .set (scope )
4956
57+ try :
58+ yield
59+ finally :
60+ self .__context_var .reset (token )
61+ self .__references .remove (scope )
5062
51- __SCOPES : Final [defaultdict [str , _ActiveScope ]] = defaultdict (_ActiveScope )
63+ @contextmanager
64+ def bind_shared_scope (self , scope : Scope ) -> Iterator [None ]:
65+ if self .__references :
66+ raise ScopeError (
67+ "A shared scope can't be defined when one or more contextual scopes "
68+ "are defined on the same name."
69+ )
5270
71+ self .__shared_value = scope
5372
54- @contextmanager
55- def bind_scope (name : str , value : Scope ) -> Iterator [None ]:
56- context_var , references = __SCOPES [name ].to_tuple ()
73+ try :
74+ yield
75+ finally :
76+ self .__shared_value = None
5777
58- if context_var .get (None ):
59- raise ScopeAlreadyDefinedError (
60- f"Scope `{ name } ` is already defined in the current context."
61- )
78+ def get_scope (self ) -> Scope | None :
79+ return self .__context_var .get (self .__shared_value )
6280
63- references . add ( value )
64- token = context_var . set ( value )
81+ def get_active_scopes ( self ) -> tuple [ Scope , ...]:
82+ references = self . __references
6583
66- try :
84+ if shared_value := self .__shared_value :
85+ return shared_value , * references
86+
87+ return tuple (references )
88+
89+
90+ __SCOPES : Final [defaultdict [str , _ScopeState ]] = defaultdict (_ScopeState )
91+
92+
93+ @asynccontextmanager
94+ async def async_scope (name : str , * , shared : bool = False ) -> AsyncIterator [None ]:
95+ async with AsyncScope () as scope :
96+ scope .enter (_bind_scope (name , scope , shared ))
97+ yield
98+
99+
100+ @contextmanager
101+ def sync_scope (name : str , * , shared : bool = False ) -> Iterator [None ]:
102+ with SyncScope () as scope :
103+ scope .enter (_bind_scope (name , scope , shared ))
67104 yield
68- finally :
69- context_var .reset (token )
70- references .discard (value )
71- value .cache .clear ()
72105
73106
74107def get_active_scopes (name : str ) -> tuple [Scope , ...]:
75- return tuple ( __SCOPES [name ].references )
108+ return __SCOPES [name ].get_active_scopes ( )
76109
77110
78111def get_scope (name : str ) -> Scope :
79- context_var = __SCOPES [name ].context_var
112+ scope = __SCOPES [name ].get_scope ()
80113
81- try :
82- return context_var .get ()
83- except LookupError as exc :
114+ if scope is None :
84115 raise ScopeUndefinedError (
85116 f"Scope `{ name } ` isn't defined in the current context."
86- ) from exc
117+ )
118+
119+ return scope
120+
121+
122+ @contextmanager
123+ def _bind_scope (name : str , value : Scope , shared : bool ) -> Iterator [None ]:
124+ state = __SCOPES [name ]
125+
126+ if state .get_scope ():
127+ raise ScopeAlreadyDefinedError (
128+ f"Scope `{ name } ` is already defined in the current context."
129+ )
130+
131+ strategy = (
132+ state .bind_shared_scope (value ) if shared else state .bind_contextual_scope (value )
133+ )
134+
135+ try :
136+ with strategy :
137+ yield
138+ finally :
139+ value .cache .clear ()
87140
88141
89142@runtime_checkable
@@ -102,22 +155,23 @@ def enter[T](self, context_manager: ContextManager[T]) -> T:
102155
103156
104157@dataclass (repr = False , frozen = True , slots = True )
105- class AsyncScope ( AsyncContextDecorator , Scope ):
106- name : str
158+ class BaseScope [ T ]( Scope , ABC ):
159+ delegate : T
107160 cache : MutableMapping [Any , Any ] = field (
108161 default_factory = dict ,
109162 init = False ,
110163 hash = False ,
111164 )
112- __exit_stack : AsyncExitStack = field (
113- default_factory = AsyncExitStack ,
114- init = False ,
115- )
165+
166+
167+ class AsyncScope (AsyncContextDecorator , BaseScope [AsyncExitStack ]):
168+ __slots__ = ()
169+
170+ def __init__ (self ) -> None :
171+ super ().__init__ (delegate = AsyncExitStack ())
116172
117173 async def __aenter__ (self ) -> Self :
118- await self .__exit_stack .__aenter__ ()
119- lifespan = bind_scope (self .name , self )
120- self .enter (lifespan )
174+ await self .delegate .__aenter__ ()
121175 return self
122176
123177 async def __aexit__ (
@@ -126,32 +180,23 @@ async def __aexit__(
126180 exc_value : BaseException | None ,
127181 traceback : TracebackType | None ,
128182 ) -> Any :
129- return await self .__exit_stack .__aexit__ (exc_type , exc_value , traceback )
183+ return await self .delegate .__aexit__ (exc_type , exc_value , traceback )
130184
131185 async def aenter [T ](self , context_manager : AsyncContextManager [T ]) -> T :
132- return await self .__exit_stack .enter_async_context (context_manager )
186+ return await self .delegate .enter_async_context (context_manager )
133187
134188 def enter [T ](self , context_manager : ContextManager [T ]) -> T :
135- return self .__exit_stack .enter_context (context_manager )
189+ return self .delegate .enter_context (context_manager )
136190
137191
138- @dataclass (repr = False , frozen = True , slots = True )
139- class SyncScope (ContextDecorator , Scope ):
140- name : str
141- cache : MutableMapping [Any , Any ] = field (
142- default_factory = dict ,
143- init = False ,
144- hash = False ,
145- )
146- __exit_stack : ExitStack = field (
147- default_factory = ExitStack ,
148- init = False ,
149- )
192+ class SyncScope (ContextDecorator , BaseScope [ExitStack ]):
193+ __slots__ = ()
194+
195+ def __init__ (self ) -> None :
196+ super ().__init__ (delegate = ExitStack ())
150197
151198 def __enter__ (self ) -> Self :
152- self .__exit_stack .__enter__ ()
153- lifespan = bind_scope (self .name , self )
154- self .enter (lifespan )
199+ self .delegate .__enter__ ()
155200 return self
156201
157202 def __exit__ (
@@ -160,10 +205,10 @@ def __exit__(
160205 exc_value : BaseException | None ,
161206 traceback : TracebackType | None ,
162207 ) -> Any :
163- return self .__exit_stack .__exit__ (exc_type , exc_value , traceback )
208+ return self .delegate .__exit__ (exc_type , exc_value , traceback )
164209
165210 async def aenter [T ](self , context_manager : AsyncContextManager [T ]) -> T :
166211 raise ScopeError ("SyncScope doesn't support asynchronous context manager." )
167212
168213 def enter [T ](self , context_manager : ContextManager [T ]) -> T :
169- return self .__exit_stack .enter_context (context_manager )
214+ return self .delegate .enter_context (context_manager )
0 commit comments