Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions injection/_core/injectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@

from injection._core.common.asynchronous import AsyncSemaphore, Caller
from injection._core.common.type import InputType
from injection._core.scope import (
Scope,
get_scope,
in_scope_cache,
remove_scoped_values,
)
from injection._core.scope import Scope, get_scope, in_scope_cache
from injection._core.slots import SlotKey
from injection.exceptions import EmptySlotError, InjectionError

Expand Down Expand Up @@ -204,9 +199,6 @@ async def abuild(self, scope: Scope) -> T:
def build(self, scope: Scope) -> T:
return self.factory.call()

def unlock(self) -> None:
remove_scoped_values(self.key, self.scope_name)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ScopedSlotInjectable[T](Injectable[T]):
Expand Down
76 changes: 29 additions & 47 deletions injection/_core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, MutableMapping
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
Expand Down Expand Up @@ -47,12 +47,12 @@ def get_default(cls) -> ScopeKind:


@runtime_checkable
class ScopeState(Protocol):
class ScopeResolver(Protocol):
__slots__ = ()

@property
@abstractmethod
def active_scopes(self) -> Iterator[Scope]:
def active_scopes(self) -> Collection[Scope]:
raise NotImplementedError

@abstractmethod
Expand All @@ -65,8 +65,8 @@ def get_scope(self) -> Scope | None:


@dataclass(repr=False, frozen=True, slots=True)
class _ContextualScopeState(ScopeState):
# Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.
class _ContextualScopeResolver(ScopeResolver):
# Shouldn't be instantiated outside `__scope_resolvers`.

__context_var: ContextVar[Scope] = field(
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
Expand All @@ -78,8 +78,8 @@ class _ContextualScopeState(ScopeState):
)

@property
def active_scopes(self) -> Iterator[Scope]:
return iter(self.__references)
def active_scopes(self) -> Collection[Scope]:
return self.__references

@contextmanager
def bind(self, scope: Scope) -> Iterator[None]:
Expand All @@ -97,13 +97,13 @@ def get_scope(self) -> Scope | None:


@dataclass(repr=False, slots=True)
class _SharedScopeState(ScopeState):
class _SharedScopeResolver(ScopeResolver):
__scope: Scope | None = field(default=None)

@property
def active_scopes(self) -> Iterator[Scope]:
if scope := self.__scope:
yield scope
def active_scopes(self) -> Collection[Scope]:
scope = self.__scope
return () if scope is None else (scope,)

@contextmanager
def bind(self, scope: Scope) -> Iterator[None]:
Expand All @@ -118,12 +118,10 @@ def get_scope(self) -> Scope | None:
return self.__scope


__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
_ContextualScopeState,
)
__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
_SharedScopeState,
)
__scope_resolvers: Final[Mapping[str, Mapping[str, ScopeResolver]]] = {
ScopeKind.CONTEXTUAL: defaultdict(_ContextualScopeResolver),
ScopeKind.SHARED: defaultdict(_SharedScopeResolver),
}


@asynccontextmanager
Expand All @@ -150,15 +148,6 @@ def define_scope(
yield facade


def get_active_scopes(name: str) -> tuple[Scope, ...]:
active_scopes = (
state.active_scopes
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES)
if (state := states.get(name))
)
return tuple(itertools.chain.from_iterable(active_scopes))


if TYPE_CHECKING: # pragma: no cover

@overload
Expand All @@ -169,9 +158,9 @@ def get_scope[T](name: str, default: T) -> Scope | T: ...


def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES):
state = states.get(name)
if state and (scope := state.get_scope()):
for resolvers in __scope_resolvers.values():
resolver = resolvers.get(name)
if resolver and (scope := resolver.get_scope()):
return scope

if default is Ellipsis:
Expand All @@ -183,12 +172,16 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:


def in_scope_cache(key: SlotKey[Any], scope_name: str) -> bool:
return any(key in scope.cache for scope in get_active_scopes(scope_name))
return any(key in scope.cache for scope in iter_active_scopes(scope_name))


def remove_scoped_values(key: SlotKey[Any], scope_name: str) -> None:
for scope in get_active_scopes(scope_name):
scope.cache.pop(key, None)
def iter_active_scopes(name: str) -> Iterator[Scope]:
active_scopes = (
resolver.active_scopes
for resolvers in __scope_resolvers.values()
if (resolver := resolvers.get(name))
)
return itertools.chain.from_iterable(active_scopes)


@contextmanager
Expand All @@ -198,28 +191,17 @@ def _bind_scope(
kind: ScopeKind | ScopeKindStr,
threadsafe: bool | None,
) -> Iterator[ScopeFacade]:
kind = ScopeKind(kind)
lock = get_lock(threadsafe)

with lock:
match ScopeKind(kind):
case ScopeKind.CONTEXTUAL:
is_already_defined = bool(get_scope(name, default=None))
states = __CONTEXTUAL_SCOPES

case ScopeKind.SHARED:
is_already_defined = bool(get_active_scopes(name))
states = __SHARED_SCOPES

case _:
raise NotImplementedError

if is_already_defined:
if get_scope(name, default=None):
raise ScopeAlreadyDefinedError(
f"Scope `{name}` is already defined in the current context."
)

stack = ExitStack()
stack.enter_context(states[name].bind(scope))
stack.enter_context(__scope_resolvers[kind][name].bind(scope))

try:
yield _UserScope(scope, lock)
Expand Down
15 changes: 5 additions & 10 deletions tests/core/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,17 +452,10 @@ class Dependency: ...
assert module.is_locked is False

with define_scope("test"):
instance_1 = module.get_instance(Dependency)
assert module.is_locked is True

module.unlock()
assert module.is_locked is False

instance_2 = module.get_instance(Dependency)
assert module.is_locked is True
module.get_instance(Dependency)

assert instance_1 is not instance_2
assert module.is_locked is False
with pytest.raises(RuntimeError):
module.unlock()

def test_unlock_with_scoped_cm_recipe(self, module):
class Dependency: ...
Expand All @@ -471,6 +464,8 @@ class Dependency: ...
def dependency_recipe() -> Iterator[Dependency]:
yield Dependency()

assert module.is_locked is False

with define_scope("test"):
module.get_instance(Dependency)

Expand Down