Skip to content

Commit 364f583

Browse files
authored
refactor: Improve scope management readability
1 parent 7a45cac commit 364f583

File tree

3 files changed

+35
-66
lines changed

3 files changed

+35
-66
lines changed

injection/_core/injectables.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616

1717
from injection._core.common.asynchronous import AsyncSemaphore, Caller
1818
from injection._core.common.type import InputType
19-
from injection._core.scope import (
20-
Scope,
21-
get_scope,
22-
in_scope_cache,
23-
remove_scoped_values,
24-
)
19+
from injection._core.scope import Scope, get_scope, in_scope_cache
2520
from injection._core.slots import SlotKey
2621
from injection.exceptions import EmptySlotError, InjectionError
2722

@@ -204,9 +199,6 @@ async def abuild(self, scope: Scope) -> T:
204199
def build(self, scope: Scope) -> T:
205200
return self.factory.call()
206201

207-
def unlock(self) -> None:
208-
remove_scoped_values(self.key, self.scope_name)
209-
210202

211203
@dataclass(repr=False, eq=False, frozen=True, slots=True)
212204
class ScopedSlotInjectable[T](Injectable[T]):

injection/_core/scope.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
from abc import ABC, abstractmethod
55
from collections import defaultdict
6-
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
6+
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, MutableMapping
77
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
88
from contextvars import ContextVar
99
from dataclasses import dataclass, field
@@ -47,12 +47,12 @@ def get_default(cls) -> ScopeKind:
4747

4848

4949
@runtime_checkable
50-
class ScopeState(Protocol):
50+
class ScopeResolver(Protocol):
5151
__slots__ = ()
5252

5353
@property
5454
@abstractmethod
55-
def active_scopes(self) -> Iterator[Scope]:
55+
def active_scopes(self) -> Collection[Scope]:
5656
raise NotImplementedError
5757

5858
@abstractmethod
@@ -65,8 +65,8 @@ def get_scope(self) -> Scope | None:
6565

6666

6767
@dataclass(repr=False, frozen=True, slots=True)
68-
class _ContextualScopeState(ScopeState):
69-
# Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.
68+
class _ContextualScopeResolver(ScopeResolver):
69+
# Shouldn't be instantiated outside `__scope_resolvers`.
7070

7171
__context_var: ContextVar[Scope] = field(
7272
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
@@ -78,8 +78,8 @@ class _ContextualScopeState(ScopeState):
7878
)
7979

8080
@property
81-
def active_scopes(self) -> Iterator[Scope]:
82-
return iter(self.__references)
81+
def active_scopes(self) -> Collection[Scope]:
82+
return self.__references
8383

8484
@contextmanager
8585
def bind(self, scope: Scope) -> Iterator[None]:
@@ -97,13 +97,13 @@ def get_scope(self) -> Scope | None:
9797

9898

9999
@dataclass(repr=False, slots=True)
100-
class _SharedScopeState(ScopeState):
100+
class _SharedScopeResolver(ScopeResolver):
101101
__scope: Scope | None = field(default=None)
102102

103103
@property
104-
def active_scopes(self) -> Iterator[Scope]:
105-
if scope := self.__scope:
106-
yield scope
104+
def active_scopes(self) -> Collection[Scope]:
105+
scope = self.__scope
106+
return () if scope is None else (scope,)
107107

108108
@contextmanager
109109
def bind(self, scope: Scope) -> Iterator[None]:
@@ -118,12 +118,10 @@ def get_scope(self) -> Scope | None:
118118
return self.__scope
119119

120120

121-
__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
122-
_ContextualScopeState,
123-
)
124-
__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
125-
_SharedScopeState,
126-
)
121+
__scope_resolvers: Final[Mapping[str, Mapping[str, ScopeResolver]]] = {
122+
ScopeKind.CONTEXTUAL: defaultdict(_ContextualScopeResolver),
123+
ScopeKind.SHARED: defaultdict(_SharedScopeResolver),
124+
}
127125

128126

129127
@asynccontextmanager
@@ -150,15 +148,6 @@ def define_scope(
150148
yield facade
151149

152150

153-
def get_active_scopes(name: str) -> tuple[Scope, ...]:
154-
active_scopes = (
155-
state.active_scopes
156-
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES)
157-
if (state := states.get(name))
158-
)
159-
return tuple(itertools.chain.from_iterable(active_scopes))
160-
161-
162151
if TYPE_CHECKING: # pragma: no cover
163152

164153
@overload
@@ -169,9 +158,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ...
169158

170159

171160
def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
172-
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES):
173-
state = states.get(name)
174-
if state and (scope := state.get_scope()):
161+
for resolvers in __scope_resolvers.values():
162+
resolver = resolvers.get(name)
163+
if resolver and (scope := resolver.get_scope()):
175164
return scope
176165

177166
if default is Ellipsis:
@@ -183,12 +172,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
183172

184173

185174
def in_scope_cache(key: SlotKey[Any], scope_name: str) -> bool:
186-
return any(key in scope.cache for scope in get_active_scopes(scope_name))
175+
return any(key in scope.cache for scope in iter_active_scopes(scope_name))
187176

188177

189-
def remove_scoped_values(key: SlotKey[Any], scope_name: str) -> None:
190-
for scope in get_active_scopes(scope_name):
191-
scope.cache.pop(key, None)
178+
def iter_active_scopes(name: str) -> Iterator[Scope]:
179+
active_scopes = (
180+
resolver.active_scopes
181+
for resolvers in __scope_resolvers.values()
182+
if (resolver := resolvers.get(name))
183+
)
184+
return itertools.chain.from_iterable(active_scopes)
192185

193186

194187
@contextmanager
@@ -198,28 +191,17 @@ def _bind_scope(
198191
kind: ScopeKind | ScopeKindStr,
199192
threadsafe: bool | None,
200193
) -> Iterator[ScopeFacade]:
194+
kind = ScopeKind(kind)
201195
lock = get_lock(threadsafe)
202196

203197
with lock:
204-
match ScopeKind(kind):
205-
case ScopeKind.CONTEXTUAL:
206-
is_already_defined = bool(get_scope(name, default=None))
207-
states = __CONTEXTUAL_SCOPES
208-
209-
case ScopeKind.SHARED:
210-
is_already_defined = bool(get_active_scopes(name))
211-
states = __SHARED_SCOPES
212-
213-
case _:
214-
raise NotImplementedError
215-
216-
if is_already_defined:
198+
if get_scope(name, default=None):
217199
raise ScopeAlreadyDefinedError(
218200
f"Scope `{name}` is already defined in the current context."
219201
)
220202

221203
stack = ExitStack()
222-
stack.enter_context(states[name].bind(scope))
204+
stack.enter_context(__scope_resolvers[kind][name].bind(scope))
223205

224206
try:
225207
yield _UserScope(scope, lock)

tests/core/test_module.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,17 +452,10 @@ class Dependency: ...
452452
assert module.is_locked is False
453453

454454
with define_scope("test"):
455-
instance_1 = module.get_instance(Dependency)
456-
assert module.is_locked is True
457-
458-
module.unlock()
459-
assert module.is_locked is False
460-
461-
instance_2 = module.get_instance(Dependency)
462-
assert module.is_locked is True
455+
module.get_instance(Dependency)
463456

464-
assert instance_1 is not instance_2
465-
assert module.is_locked is False
457+
with pytest.raises(RuntimeError):
458+
module.unlock()
466459

467460
def test_unlock_with_scoped_cm_recipe(self, module):
468461
class Dependency: ...
@@ -471,6 +464,8 @@ class Dependency: ...
471464
def dependency_recipe() -> Iterator[Dependency]:
472465
yield Dependency()
473466

467+
assert module.is_locked is False
468+
474469
with define_scope("test"):
475470
module.get_instance(Dependency)
476471

0 commit comments

Comments
 (0)