33import itertools
44from abc import ABC , abstractmethod
55from collections import defaultdict
6- from collections .abc import AsyncIterator , Iterator , Mapping , MutableMapping
6+ from collections .abc import AsyncIterator , Collection , Iterator , Mapping , MutableMapping
77from contextlib import AsyncExitStack , ExitStack , asynccontextmanager , contextmanager
88from contextvars import ContextVar
99from dataclasses import dataclass , field
@@ -47,12 +47,12 @@ def get_default(cls) -> ScopeKind:
4747
4848
4949@runtime_checkable
50- class ScopeState (Protocol ):
50+ class ScopeResolver (Protocol ):
5151 __slots__ = ()
5252
5353 @property
5454 @abstractmethod
55- def active_scopes (self ) -> Iterator [Scope ]:
55+ def active_scopes (self ) -> Collection [Scope ]:
5656 raise NotImplementedError
5757
5858 @abstractmethod
@@ -65,8 +65,8 @@ def get_scope(self) -> Scope | None:
6565
6666
6767@dataclass (repr = False , frozen = True , slots = True )
68- class _ContextualScopeState ( ScopeState ):
69- # Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES `.
68+ class _ContextualScopeResolver ( ScopeResolver ):
69+ # Shouldn't be instantiated outside `__scope_resolvers `.
7070
7171 __context_var : ContextVar [Scope ] = field (
7272 default_factory = lambda : ContextVar (f"scope@{ new_short_key ()} " ),
@@ -78,8 +78,8 @@ class _ContextualScopeState(ScopeState):
7878 )
7979
8080 @property
81- def active_scopes (self ) -> Iterator [Scope ]:
82- return iter ( self .__references )
81+ def active_scopes (self ) -> Collection [Scope ]:
82+ return self .__references
8383
8484 @contextmanager
8585 def bind (self , scope : Scope ) -> Iterator [None ]:
@@ -97,13 +97,13 @@ def get_scope(self) -> Scope | None:
9797
9898
9999@dataclass (repr = False , slots = True )
100- class _SharedScopeState ( ScopeState ):
100+ class _SharedScopeResolver ( ScopeResolver ):
101101 __scope : Scope | None = field (default = None )
102102
103103 @property
104- def active_scopes (self ) -> Iterator [Scope ]:
105- if scope : = self .__scope :
106- yield scope
104+ def active_scopes (self ) -> Collection [Scope ]:
105+ scope = self .__scope
106+ return () if scope is None else ( scope ,)
107107
108108 @contextmanager
109109 def bind (self , scope : Scope ) -> Iterator [None ]:
@@ -118,12 +118,10 @@ def get_scope(self) -> Scope | None:
118118 return self .__scope
119119
120120
121- __CONTEXTUAL_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
122- _ContextualScopeState ,
123- )
124- __SHARED_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
125- _SharedScopeState ,
126- )
121+ __scope_resolvers : Final [Mapping [str , Mapping [str , ScopeResolver ]]] = {
122+ ScopeKind .CONTEXTUAL : defaultdict (_ContextualScopeResolver ),
123+ ScopeKind .SHARED : defaultdict (_SharedScopeResolver ),
124+ }
127125
128126
129127@asynccontextmanager
@@ -150,15 +148,6 @@ def define_scope(
150148 yield facade
151149
152150
153- def get_active_scopes (name : str ) -> tuple [Scope , ...]:
154- active_scopes = (
155- state .active_scopes
156- for states in (__CONTEXTUAL_SCOPES , __SHARED_SCOPES )
157- if (state := states .get (name ))
158- )
159- return tuple (itertools .chain .from_iterable (active_scopes ))
160-
161-
162151if TYPE_CHECKING : # pragma: no cover
163152
164153 @overload
@@ -169,9 +158,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ...
169158
170159
171160def get_scope [T ](name : str , default : T | EllipsisType = ...) -> Scope | T :
172- for states in ( __CONTEXTUAL_SCOPES , __SHARED_SCOPES ):
173- state = states .get (name )
174- if state and (scope := state .get_scope ()):
161+ for resolvers in __scope_resolvers . values ( ):
162+ resolver = resolvers .get (name )
163+ if resolver and (scope := resolver .get_scope ()):
175164 return scope
176165
177166 if default is Ellipsis :
@@ -183,12 +172,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
183172
184173
185174def in_scope_cache (key : SlotKey [Any ], scope_name : str ) -> bool :
186- return any (key in scope .cache for scope in get_active_scopes (scope_name ))
175+ return any (key in scope .cache for scope in iter_active_scopes (scope_name ))
187176
188177
189- def remove_scoped_values (key : SlotKey [Any ], scope_name : str ) -> None :
190- for scope in get_active_scopes (scope_name ):
191- scope .cache .pop (key , None )
178+ def iter_active_scopes (name : str ) -> Iterator [Scope ]:
179+ active_scopes = (
180+ resolver .active_scopes
181+ for resolvers in __scope_resolvers .values ()
182+ if (resolver := resolvers .get (name ))
183+ )
184+ return itertools .chain .from_iterable (active_scopes )
192185
193186
194187@contextmanager
@@ -198,28 +191,17 @@ def _bind_scope(
198191 kind : ScopeKind | ScopeKindStr ,
199192 threadsafe : bool | None ,
200193) -> Iterator [ScopeFacade ]:
194+ kind = ScopeKind (kind )
201195 lock = get_lock (threadsafe )
202196
203197 with lock :
204- match ScopeKind (kind ):
205- case ScopeKind .CONTEXTUAL :
206- is_already_defined = bool (get_scope (name , default = None ))
207- states = __CONTEXTUAL_SCOPES
208-
209- case ScopeKind .SHARED :
210- is_already_defined = bool (get_active_scopes (name ))
211- states = __SHARED_SCOPES
212-
213- case _:
214- raise NotImplementedError
215-
216- if is_already_defined :
198+ if get_scope (name , default = None ):
217199 raise ScopeAlreadyDefinedError (
218200 f"Scope `{ name } ` is already defined in the current context."
219201 )
220202
221203 stack = ExitStack ()
222- stack .enter_context (states [name ].bind (scope ))
204+ stack .enter_context (__scope_resolvers [ kind ] [name ].bind (scope ))
223205
224206 try :
225207 yield _UserScope (scope , lock )
0 commit comments