Skip to content

Commit 392c3fe

Browse files
author
remimd
committed
refactor: Improve scope management readability
1 parent 7a45cac commit 392c3fe

File tree

3 files changed

+41
-66
lines changed

3 files changed

+41
-66
lines changed

injection/_core/injectables.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616

1717
from injection._core.common.asynchronous import AsyncSemaphore, Caller
1818
from injection._core.common.type import InputType
19-
from injection._core.scope import (
20-
Scope,
21-
get_scope,
22-
in_scope_cache,
23-
remove_scoped_values,
24-
)
19+
from injection._core.scope import Scope, get_scope, in_scope_cache
2520
from injection._core.slots import SlotKey
2621
from injection.exceptions import EmptySlotError, InjectionError
2722

@@ -204,9 +199,6 @@ async def abuild(self, scope: Scope) -> T:
204199
def build(self, scope: Scope) -> T:
205200
return self.factory.call()
206201

207-
def unlock(self) -> None:
208-
remove_scoped_values(self.key, self.scope_name)
209-
210202

211203
@dataclass(repr=False, eq=False, frozen=True, slots=True)
212204
class ScopedSlotInjectable[T](Injectable[T]):

injection/_core/scope.py

Lines changed: 35 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import itertools
44
from abc import ABC, abstractmethod
55
from collections import defaultdict
6-
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
6+
from collections.abc import (
7+
AsyncIterator,
8+
Collection,
9+
Iterator,
10+
Mapping,
11+
MutableMapping,
12+
)
713
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
814
from contextvars import ContextVar
915
from dataclasses import dataclass, field
@@ -47,12 +53,12 @@ def get_default(cls) -> ScopeKind:
4753

4854

4955
@runtime_checkable
50-
class ScopeState(Protocol):
56+
class ScopeResolver(Protocol):
5157
__slots__ = ()
5258

5359
@property
5460
@abstractmethod
55-
def active_scopes(self) -> Iterator[Scope]:
61+
def active_scopes(self) -> Collection[Scope]:
5662
raise NotImplementedError
5763

5864
@abstractmethod
@@ -65,8 +71,8 @@ def get_scope(self) -> Scope | None:
6571

6672

6773
@dataclass(repr=False, frozen=True, slots=True)
68-
class _ContextualScopeState(ScopeState):
69-
# Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.
74+
class _ContextualScopeResolver(ScopeResolver):
75+
# Shouldn't be instantiated outside `__scope_resolvers`.
7076

7177
__context_var: ContextVar[Scope] = field(
7278
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
@@ -78,8 +84,8 @@ class _ContextualScopeState(ScopeState):
7884
)
7985

8086
@property
81-
def active_scopes(self) -> Iterator[Scope]:
82-
return iter(self.__references)
87+
def active_scopes(self) -> Collection[Scope]:
88+
return self.__references
8389

8490
@contextmanager
8591
def bind(self, scope: Scope) -> Iterator[None]:
@@ -97,13 +103,13 @@ def get_scope(self) -> Scope | None:
97103

98104

99105
@dataclass(repr=False, slots=True)
100-
class _SharedScopeState(ScopeState):
106+
class _SharedScopeResolver(ScopeResolver):
101107
__scope: Scope | None = field(default=None)
102108

103109
@property
104-
def active_scopes(self) -> Iterator[Scope]:
105-
if scope := self.__scope:
106-
yield scope
110+
def active_scopes(self) -> Collection[Scope]:
111+
scope = self.__scope
112+
return () if scope is None else (scope,)
107113

108114
@contextmanager
109115
def bind(self, scope: Scope) -> Iterator[None]:
@@ -118,12 +124,10 @@ def get_scope(self) -> Scope | None:
118124
return self.__scope
119125

120126

121-
__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
122-
_ContextualScopeState,
123-
)
124-
__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
125-
_SharedScopeState,
126-
)
127+
__scope_resolvers: Final[Mapping[str, Mapping[str, ScopeResolver]]] = {
128+
ScopeKind.CONTEXTUAL: defaultdict(_ContextualScopeResolver),
129+
ScopeKind.SHARED: defaultdict(_SharedScopeResolver),
130+
}
127131

128132

129133
@asynccontextmanager
@@ -150,15 +154,6 @@ def define_scope(
150154
yield facade
151155

152156

153-
def get_active_scopes(name: str) -> tuple[Scope, ...]:
154-
active_scopes = (
155-
state.active_scopes
156-
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES)
157-
if (state := states.get(name))
158-
)
159-
return tuple(itertools.chain.from_iterable(active_scopes))
160-
161-
162157
if TYPE_CHECKING: # pragma: no cover
163158

164159
@overload
@@ -169,9 +164,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ...
169164

170165

171166
def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
172-
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES):
173-
state = states.get(name)
174-
if state and (scope := state.get_scope()):
167+
for resolvers in __scope_resolvers.values():
168+
resolver = resolvers.get(name)
169+
if resolver and (scope := resolver.get_scope()):
175170
return scope
176171

177172
if default is Ellipsis:
@@ -183,12 +178,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
183178

184179

185180
def in_scope_cache(key: SlotKey[Any], scope_name: str) -> bool:
186-
return any(key in scope.cache for scope in get_active_scopes(scope_name))
181+
return any(key in scope.cache for scope in iter_active_scopes(scope_name))
187182

188183

189-
def remove_scoped_values(key: SlotKey[Any], scope_name: str) -> None:
190-
for scope in get_active_scopes(scope_name):
191-
scope.cache.pop(key, None)
184+
def iter_active_scopes(name: str) -> Iterator[Scope]:
185+
active_scopes = (
186+
resolver.active_scopes
187+
for resolvers in __scope_resolvers.values()
188+
if (resolver := resolvers.get(name))
189+
)
190+
return itertools.chain.from_iterable(active_scopes)
192191

193192

194193
@contextmanager
@@ -198,28 +197,17 @@ def _bind_scope(
198197
kind: ScopeKind | ScopeKindStr,
199198
threadsafe: bool | None,
200199
) -> Iterator[ScopeFacade]:
200+
kind = ScopeKind(kind)
201201
lock = get_lock(threadsafe)
202202

203203
with lock:
204-
match ScopeKind(kind):
205-
case ScopeKind.CONTEXTUAL:
206-
is_already_defined = bool(get_scope(name, default=None))
207-
states = __CONTEXTUAL_SCOPES
208-
209-
case ScopeKind.SHARED:
210-
is_already_defined = bool(get_active_scopes(name))
211-
states = __SHARED_SCOPES
212-
213-
case _:
214-
raise NotImplementedError
215-
216-
if is_already_defined:
204+
if next(iter_active_scopes(name), None):
217205
raise ScopeAlreadyDefinedError(
218206
f"Scope `{name}` is already defined in the current context."
219207
)
220208

221209
stack = ExitStack()
222-
stack.enter_context(states[name].bind(scope))
210+
stack.enter_context(__scope_resolvers[kind][name].bind(scope))
223211

224212
try:
225213
yield _UserScope(scope, lock)

tests/core/test_module.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,17 +452,10 @@ class Dependency: ...
452452
assert module.is_locked is False
453453

454454
with define_scope("test"):
455-
instance_1 = module.get_instance(Dependency)
456-
assert module.is_locked is True
457-
458-
module.unlock()
459-
assert module.is_locked is False
460-
461-
instance_2 = module.get_instance(Dependency)
462-
assert module.is_locked is True
455+
module.get_instance(Dependency)
463456

464-
assert instance_1 is not instance_2
465-
assert module.is_locked is False
457+
with pytest.raises(RuntimeError):
458+
module.unlock()
466459

467460
def test_unlock_with_scoped_cm_recipe(self, module):
468461
class Dependency: ...
@@ -471,6 +464,8 @@ class Dependency: ...
471464
def dependency_recipe() -> Iterator[Dependency]:
472465
yield Dependency()
473466

467+
assert module.is_locked is False
468+
474469
with define_scope("test"):
475470
module.get_instance(Dependency)
476471

0 commit comments

Comments
 (0)