33import itertools
44from abc import ABC , abstractmethod
55from collections import defaultdict
6- from collections .abc import AsyncIterator , Iterator , Mapping , MutableMapping
6+ from collections .abc import (
7+ AsyncIterator ,
8+ Collection ,
9+ Iterator ,
10+ Mapping ,
11+ MutableMapping ,
12+ )
713from contextlib import AsyncExitStack , ExitStack , asynccontextmanager , contextmanager
814from contextvars import ContextVar
915from dataclasses import dataclass , field
@@ -47,12 +53,12 @@ def get_default(cls) -> ScopeKind:
4753
4854
4955@runtime_checkable
50- class ScopeState (Protocol ):
56+ class ScopeResolver (Protocol ):
5157 __slots__ = ()
5258
5359 @property
5460 @abstractmethod
55- def active_scopes (self ) -> Iterator [Scope ]:
61+ def active_scopes (self ) -> Collection [Scope ]:
5662 raise NotImplementedError
5763
5864 @abstractmethod
@@ -65,8 +71,8 @@ def get_scope(self) -> Scope | None:
6571
6672
6773@dataclass (repr = False , frozen = True , slots = True )
68- class _ContextualScopeState ( ScopeState ):
69- # Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES `.
74+ class _ContextualScopeResolver ( ScopeResolver ):
75+ # Shouldn't be instantiated outside `__scope_resolvers `.
7076
7177 __context_var : ContextVar [Scope ] = field (
7278 default_factory = lambda : ContextVar (f"scope@{ new_short_key ()} " ),
@@ -78,8 +84,8 @@ class _ContextualScopeState(ScopeState):
7884 )
7985
8086 @property
81- def active_scopes (self ) -> Iterator [Scope ]:
82- return iter ( self .__references )
87+ def active_scopes (self ) -> Collection [Scope ]:
88+ return self .__references
8389
8490 @contextmanager
8591 def bind (self , scope : Scope ) -> Iterator [None ]:
@@ -97,13 +103,13 @@ def get_scope(self) -> Scope | None:
97103
98104
99105@dataclass (repr = False , slots = True )
100- class _SharedScopeState ( ScopeState ):
106+ class _SharedScopeResolver ( ScopeResolver ):
101107 __scope : Scope | None = field (default = None )
102108
103109 @property
104- def active_scopes (self ) -> Iterator [Scope ]:
105- if scope : = self .__scope :
106- yield scope
110+ def active_scopes (self ) -> Collection [Scope ]:
111+ scope = self .__scope
112+ return () if scope is None else ( scope ,)
107113
108114 @contextmanager
109115 def bind (self , scope : Scope ) -> Iterator [None ]:
@@ -118,12 +124,10 @@ def get_scope(self) -> Scope | None:
118124 return self .__scope
119125
120126
121- __CONTEXTUAL_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
122- _ContextualScopeState ,
123- )
124- __SHARED_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
125- _SharedScopeState ,
126- )
127+ __scope_resolvers : Final [Mapping [str , Mapping [str , ScopeResolver ]]] = {
128+ ScopeKind .CONTEXTUAL : defaultdict (_ContextualScopeResolver ),
129+ ScopeKind .SHARED : defaultdict (_SharedScopeResolver ),
130+ }
127131
128132
129133@asynccontextmanager
@@ -150,15 +154,6 @@ def define_scope(
150154 yield facade
151155
152156
153- def get_active_scopes (name : str ) -> tuple [Scope , ...]:
154- active_scopes = (
155- state .active_scopes
156- for states in (__CONTEXTUAL_SCOPES , __SHARED_SCOPES )
157- if (state := states .get (name ))
158- )
159- return tuple (itertools .chain .from_iterable (active_scopes ))
160-
161-
162157if TYPE_CHECKING : # pragma: no cover
163158
164159 @overload
@@ -169,9 +164,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ...
169164
170165
171166def get_scope [T ](name : str , default : T | EllipsisType = ...) -> Scope | T :
172- for states in ( __CONTEXTUAL_SCOPES , __SHARED_SCOPES ):
173- state = states .get (name )
174- if state and (scope := state .get_scope ()):
167+ for resolvers in __scope_resolvers . values ( ):
168+ resolver = resolvers .get (name )
169+ if resolver and (scope := resolver .get_scope ()):
175170 return scope
176171
177172 if default is Ellipsis :
@@ -183,12 +178,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
183178
184179
185180def in_scope_cache (key : SlotKey [Any ], scope_name : str ) -> bool :
186- return any (key in scope .cache for scope in get_active_scopes (scope_name ))
181+ return any (key in scope .cache for scope in iter_active_scopes (scope_name ))
187182
188183
189- def remove_scoped_values (key : SlotKey [Any ], scope_name : str ) -> None :
190- for scope in get_active_scopes (scope_name ):
191- scope .cache .pop (key , None )
184+ def iter_active_scopes (name : str ) -> Iterator [Scope ]:
185+ active_scopes = (
186+ resolver .active_scopes
187+ for resolvers in __scope_resolvers .values ()
188+ if (resolver := resolvers .get (name ))
189+ )
190+ return itertools .chain .from_iterable (active_scopes )
192191
193192
194193@contextmanager
@@ -198,28 +197,17 @@ def _bind_scope(
198197 kind : ScopeKind | ScopeKindStr ,
199198 threadsafe : bool | None ,
200199) -> Iterator [ScopeFacade ]:
200+ kind = ScopeKind (kind )
201201 lock = get_lock (threadsafe )
202202
203203 with lock :
204- match ScopeKind (kind ):
205- case ScopeKind .CONTEXTUAL :
206- is_already_defined = bool (get_scope (name , default = None ))
207- states = __CONTEXTUAL_SCOPES
208-
209- case ScopeKind .SHARED :
210- is_already_defined = bool (get_active_scopes (name ))
211- states = __SHARED_SCOPES
212-
213- case _:
214- raise NotImplementedError
215-
216- if is_already_defined :
204+ if next (iter_active_scopes (name ), None ):
217205 raise ScopeAlreadyDefinedError (
218206 f"Scope `{ name } ` is already defined in the current context."
219207 )
220208
221209 stack = ExitStack ()
222- stack .enter_context (states [name ].bind (scope ))
210+ stack .enter_context (__scope_resolvers [ kind ] [name ].bind (scope ))
223211
224212 try :
225213 yield _UserScope (scope , lock )
0 commit comments