diff --git a/injection/_core/injectables.py b/injection/_core/injectables.py index 79f4743..c14ff28 100644 --- a/injection/_core/injectables.py +++ b/injection/_core/injectables.py @@ -16,12 +16,7 @@ from injection._core.common.asynchronous import AsyncSemaphore, Caller from injection._core.common.type import InputType -from injection._core.scope import ( - Scope, - get_scope, - in_scope_cache, - remove_scoped_values, -) +from injection._core.scope import Scope, get_scope, in_scope_cache from injection._core.slots import SlotKey from injection.exceptions import EmptySlotError, InjectionError @@ -204,9 +199,6 @@ async def abuild(self, scope: Scope) -> T: def build(self, scope: Scope) -> T: return self.factory.call() - def unlock(self) -> None: - remove_scoped_values(self.key, self.scope_name) - @dataclass(repr=False, eq=False, frozen=True, slots=True) class ScopedSlotInjectable[T](Injectable[T]): diff --git a/injection/_core/scope.py b/injection/_core/scope.py index 88c41e4..70e2849 100644 --- a/injection/_core/scope.py +++ b/injection/_core/scope.py @@ -3,7 +3,7 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping +from collections.abc import AsyncIterator, Collection, Iterator, Mapping, MutableMapping from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import dataclass, field @@ -47,12 +47,12 @@ def get_default(cls) -> ScopeKind: @runtime_checkable -class ScopeState(Protocol): +class ScopeResolver(Protocol): __slots__ = () @property @abstractmethod - def active_scopes(self) -> Iterator[Scope]: + def active_scopes(self) -> Collection[Scope]: raise NotImplementedError @abstractmethod @@ -65,8 +65,8 @@ def get_scope(self) -> Scope | None: @dataclass(repr=False, frozen=True, slots=True) -class _ContextualScopeState(ScopeState): - # Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`. +class _ContextualScopeResolver(ScopeResolver): + # Shouldn't be instantiated outside `__scope_resolvers`. __context_var: ContextVar[Scope] = field( default_factory=lambda: ContextVar(f"scope@{new_short_key()}"), @@ -78,8 +78,8 @@ class _ContextualScopeState(ScopeState): ) @property - def active_scopes(self) -> Iterator[Scope]: - return iter(self.__references) + def active_scopes(self) -> Collection[Scope]: + return self.__references @contextmanager def bind(self, scope: Scope) -> Iterator[None]: @@ -97,13 +97,13 @@ def get_scope(self) -> Scope | None: @dataclass(repr=False, slots=True) -class _SharedScopeState(ScopeState): +class _SharedScopeResolver(ScopeResolver): __scope: Scope | None = field(default=None) @property - def active_scopes(self) -> Iterator[Scope]: - if scope := self.__scope: - yield scope + def active_scopes(self) -> Collection[Scope]: + scope = self.__scope + return () if scope is None else (scope,) @contextmanager def bind(self, scope: Scope) -> Iterator[None]: @@ -118,12 +118,10 @@ def get_scope(self) -> Scope | None: return self.__scope -__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict( - _ContextualScopeState, -) -__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict( - _SharedScopeState, -) +__scope_resolvers: Final[Mapping[str, Mapping[str, ScopeResolver]]] = { + ScopeKind.CONTEXTUAL: defaultdict(_ContextualScopeResolver), + ScopeKind.SHARED: defaultdict(_SharedScopeResolver), +} @asynccontextmanager @@ -150,15 +148,6 @@ def define_scope( yield facade -def get_active_scopes(name: str) -> tuple[Scope, ...]: - active_scopes = ( - state.active_scopes - for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES) - if (state := states.get(name)) - ) - return tuple(itertools.chain.from_iterable(active_scopes)) - - if TYPE_CHECKING: # pragma: no cover @overload @@ -169,9 +158,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ... def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T: - for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES): - state = states.get(name) - if state and (scope := state.get_scope()): + for resolvers in __scope_resolvers.values(): + resolver = resolvers.get(name) + if resolver and (scope := resolver.get_scope()): return scope if default is Ellipsis: @@ -183,12 +172,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T: def in_scope_cache(key: SlotKey[Any], scope_name: str) -> bool: - return any(key in scope.cache for scope in get_active_scopes(scope_name)) + return any(key in scope.cache for scope in iter_active_scopes(scope_name)) -def remove_scoped_values(key: SlotKey[Any], scope_name: str) -> None: - for scope in get_active_scopes(scope_name): - scope.cache.pop(key, None) +def iter_active_scopes(name: str) -> Iterator[Scope]: + active_scopes = ( + resolver.active_scopes + for resolvers in __scope_resolvers.values() + if (resolver := resolvers.get(name)) + ) + return itertools.chain.from_iterable(active_scopes) @contextmanager @@ -198,28 +191,17 @@ def _bind_scope( kind: ScopeKind | ScopeKindStr, threadsafe: bool | None, ) -> Iterator[ScopeFacade]: + kind = ScopeKind(kind) lock = get_lock(threadsafe) with lock: - match ScopeKind(kind): - case ScopeKind.CONTEXTUAL: - is_already_defined = bool(get_scope(name, default=None)) - states = __CONTEXTUAL_SCOPES - - case ScopeKind.SHARED: - is_already_defined = bool(get_active_scopes(name)) - states = __SHARED_SCOPES - - case _: - raise NotImplementedError - - if is_already_defined: + if get_scope(name, default=None): raise ScopeAlreadyDefinedError( f"Scope `{name}` is already defined in the current context." ) stack = ExitStack() - stack.enter_context(states[name].bind(scope)) + stack.enter_context(__scope_resolvers[kind][name].bind(scope)) try: yield _UserScope(scope, lock) diff --git a/tests/core/test_module.py b/tests/core/test_module.py index a6ddeb5..6b7c21b 100644 --- a/tests/core/test_module.py +++ b/tests/core/test_module.py @@ -452,17 +452,10 @@ class Dependency: ... assert module.is_locked is False with define_scope("test"): - instance_1 = module.get_instance(Dependency) - assert module.is_locked is True - - module.unlock() - assert module.is_locked is False - - instance_2 = module.get_instance(Dependency) - assert module.is_locked is True + module.get_instance(Dependency) - assert instance_1 is not instance_2 - assert module.is_locked is False + with pytest.raises(RuntimeError): + module.unlock() def test_unlock_with_scoped_cm_recipe(self, module): class Dependency: ... @@ -471,6 +464,8 @@ class Dependency: ... def dependency_recipe() -> Iterator[Dependency]: yield Dependency() + assert module.is_locked is False + with define_scope("test"): module.get_instance(Dependency)