Skip to content

Commit ca4fea1

Browse files
author
remimd
committed
wip
1 parent b2f1800 commit ca4fea1

File tree

5 files changed

+125
-98
lines changed

5 files changed

+125
-98
lines changed

injection/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
from ._core.descriptors import LazyInstance
22
from ._core.injectables import Injectable
33
from ._core.module import Mode, Module, Priority, mod
4-
from ._core.scope import AsyncScope, SyncScope
4+
from ._core.scope import async_scope, sync_scope
55

66
__all__ = (
7-
"AsyncScope",
87
"Injectable",
98
"LazyInstance",
109
"Mode",
1110
"Module",
1211
"Priority",
13-
"SyncScope",
1412
"afind_instance",
1513
"aget_instance",
1614
"aget_lazy_instance",
15+
"async_scope",
1716
"constant",
1817
"find_instance",
1918
"get_instance",
@@ -25,6 +24,7 @@
2524
"set_constant",
2625
"should_be_injectable",
2726
"singleton",
27+
"sync_scope",
2828
)
2929

3030
afind_instance = mod().afind_instance

injection/__init__.pyi

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from abc import abstractmethod
22
from collections.abc import Awaitable, Callable
3-
from contextlib import AsyncContextDecorator, ContextDecorator
3+
from contextlib import ContextDecorator
44
from enum import Enum
55
from logging import Logger
6-
from types import TracebackType
76
from typing import (
87
Any,
8+
AsyncContextManager,
99
ContextManager,
1010
Final,
1111
Protocol,
@@ -37,6 +37,8 @@ set_constant = __MODULE.set_constant
3737
should_be_injectable = __MODULE.should_be_injectable
3838
singleton = __MODULE.singleton
3939

40+
def async_scope(name: str, *, shared: bool = ...) -> AsyncContextManager[None]: ...
41+
def sync_scope(name: str, *, shared: bool = ...) -> ContextManager[None]: ...
4042
def mod(name: str = ..., /) -> Module:
4143
"""
4244
Short syntax for `Module.from_name`.
@@ -318,23 +320,3 @@ class Mode(Enum):
318320
class Priority(Enum):
319321
LOW = ...
320322
HIGH = ...
321-
322-
class AsyncScope(AsyncContextDecorator):
323-
def __init__(self, name: str) -> None: ...
324-
async def __aenter__(self) -> Self: ...
325-
async def __aexit__(
326-
self,
327-
exc_type: type[BaseException] | None,
328-
exc_value: BaseException | None,
329-
traceback: TracebackType | None,
330-
) -> Any: ...
331-
332-
class SyncScope(ContextDecorator):
333-
def __init__(self, name: str) -> None: ...
334-
def __enter__(self) -> Self: ...
335-
def __exit__(
336-
self,
337-
exc_type: type[BaseException] | None,
338-
exc_value: BaseException | None,
339-
traceback: TracebackType | None,
340-
) -> Any: ...

injection/_core/scope.py

Lines changed: 103 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3-
from abc import abstractmethod
3+
from abc import ABC, abstractmethod
44
from collections import defaultdict
5-
from collections.abc import Iterator, MutableMapping
5+
from collections.abc import AsyncIterator, Iterator, MutableMapping
66
from contextlib import (
77
AsyncContextDecorator,
88
AsyncExitStack,
99
ContextDecorator,
1010
ExitStack,
11+
asynccontextmanager,
1112
contextmanager,
1213
)
1314
from contextvars import ContextVar
@@ -31,59 +32,111 @@
3132
)
3233

3334

34-
@dataclass(repr=False, frozen=True, slots=True)
35-
class _ActiveScope:
35+
@dataclass(repr=False, slots=True)
36+
class _ScopeState:
3637
# Shouldn't be instantiated outside `__SCOPES`.
3738

38-
context_var: ContextVar[Scope] = field(
39+
__context_var: ContextVar[Scope] = field(
3940
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
4041
init=False,
4142
)
42-
references: set[Scope] = field(
43+
__references: set[Scope] = field(
4344
default_factory=set,
4445
init=False,
4546
)
47+
__shared_value: Scope | None = field(
48+
default=None,
49+
init=False,
50+
)
4651

47-
def to_tuple(self) -> tuple[ContextVar[Scope], set[Scope]]:
48-
return self.context_var, self.references
52+
@contextmanager
53+
def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
54+
self.__references.add(scope)
55+
token = self.__context_var.set(scope)
4956

57+
try:
58+
yield
59+
finally:
60+
self.__context_var.reset(token)
61+
self.__references.remove(scope)
5062

51-
__SCOPES: Final[defaultdict[str, _ActiveScope]] = defaultdict(_ActiveScope)
63+
@contextmanager
64+
def bind_shared_scope(self, scope: Scope) -> Iterator[None]:
65+
if self.__references:
66+
raise ScopeError(
67+
"A shared scope can't be defined when one or more contextual scopes "
68+
"are defined on the same name."
69+
)
5270

71+
self.__shared_value = scope
5372

54-
@contextmanager
55-
def bind_scope(name: str, value: Scope) -> Iterator[None]:
56-
context_var, references = __SCOPES[name].to_tuple()
73+
try:
74+
yield
75+
finally:
76+
self.__shared_value = None
5777

58-
if context_var.get(None):
59-
raise ScopeAlreadyDefinedError(
60-
f"Scope `{name}` is already defined in the current context."
61-
)
78+
def get_scope(self) -> Scope | None:
79+
return self.__context_var.get(self.__shared_value)
6280

63-
references.add(value)
64-
token = context_var.set(value)
81+
def get_active_scopes(self) -> tuple[Scope, ...]:
82+
references = self.__references
6583

66-
try:
84+
if shared_value := self.__shared_value:
85+
return shared_value, *references
86+
87+
return tuple(references)
88+
89+
90+
__SCOPES: Final[defaultdict[str, _ScopeState]] = defaultdict(_ScopeState)
91+
92+
93+
@asynccontextmanager
94+
async def async_scope(name: str, *, shared: bool = False) -> AsyncIterator[None]:
95+
async with AsyncScope() as scope:
96+
scope.enter(_bind_scope(name, scope, shared))
97+
yield
98+
99+
100+
@contextmanager
101+
def sync_scope(name: str, *, shared: bool = False) -> Iterator[None]:
102+
with SyncScope() as scope:
103+
scope.enter(_bind_scope(name, scope, shared))
67104
yield
68-
finally:
69-
context_var.reset(token)
70-
references.discard(value)
71-
value.cache.clear()
72105

73106

74107
def get_active_scopes(name: str) -> tuple[Scope, ...]:
75-
return tuple(__SCOPES[name].references)
108+
return __SCOPES[name].get_active_scopes()
76109

77110

78111
def get_scope(name: str) -> Scope:
79-
context_var = __SCOPES[name].context_var
112+
scope = __SCOPES[name].get_scope()
80113

81-
try:
82-
return context_var.get()
83-
except LookupError as exc:
114+
if scope is None:
84115
raise ScopeUndefinedError(
85116
f"Scope `{name}` isn't defined in the current context."
86-
) from exc
117+
)
118+
119+
return scope
120+
121+
122+
@contextmanager
123+
def _bind_scope(name: str, value: Scope, shared: bool) -> Iterator[None]:
124+
state = __SCOPES[name]
125+
126+
if state.get_scope():
127+
raise ScopeAlreadyDefinedError(
128+
f"Scope `{name}` is already defined in the current context."
129+
)
130+
131+
strategy = (
132+
state.bind_shared_scope(value) if shared else state.bind_contextual_scope(value)
133+
)
134+
135+
try:
136+
with strategy:
137+
yield
138+
finally:
139+
value.cache.clear()
87140

88141

89142
@runtime_checkable
@@ -102,22 +155,23 @@ def enter[T](self, context_manager: ContextManager[T]) -> T:
102155

103156

104157
@dataclass(repr=False, frozen=True, slots=True)
105-
class AsyncScope(AsyncContextDecorator, Scope):
106-
name: str
158+
class BaseScope[T](Scope, ABC):
159+
delegate: T
107160
cache: MutableMapping[Any, Any] = field(
108161
default_factory=dict,
109162
init=False,
110163
hash=False,
111164
)
112-
__exit_stack: AsyncExitStack = field(
113-
default_factory=AsyncExitStack,
114-
init=False,
115-
)
165+
166+
167+
class AsyncScope(AsyncContextDecorator, BaseScope[AsyncExitStack]):
168+
__slots__ = ()
169+
170+
def __init__(self) -> None:
171+
super().__init__(delegate=AsyncExitStack())
116172

117173
async def __aenter__(self) -> Self:
118-
await self.__exit_stack.__aenter__()
119-
lifespan = bind_scope(self.name, self)
120-
self.enter(lifespan)
174+
await self.delegate.__aenter__()
121175
return self
122176

123177
async def __aexit__(
@@ -126,32 +180,23 @@ async def __aexit__(
126180
exc_value: BaseException | None,
127181
traceback: TracebackType | None,
128182
) -> Any:
129-
return await self.__exit_stack.__aexit__(exc_type, exc_value, traceback)
183+
return await self.delegate.__aexit__(exc_type, exc_value, traceback)
130184

131185
async def aenter[T](self, context_manager: AsyncContextManager[T]) -> T:
132-
return await self.__exit_stack.enter_async_context(context_manager)
186+
return await self.delegate.enter_async_context(context_manager)
133187

134188
def enter[T](self, context_manager: ContextManager[T]) -> T:
135-
return self.__exit_stack.enter_context(context_manager)
189+
return self.delegate.enter_context(context_manager)
136190

137191

138-
@dataclass(repr=False, frozen=True, slots=True)
139-
class SyncScope(ContextDecorator, Scope):
140-
name: str
141-
cache: MutableMapping[Any, Any] = field(
142-
default_factory=dict,
143-
init=False,
144-
hash=False,
145-
)
146-
__exit_stack: ExitStack = field(
147-
default_factory=ExitStack,
148-
init=False,
149-
)
192+
class SyncScope(ContextDecorator, BaseScope[ExitStack]):
193+
__slots__ = ()
194+
195+
def __init__(self) -> None:
196+
super().__init__(delegate=ExitStack())
150197

151198
def __enter__(self) -> Self:
152-
self.__exit_stack.__enter__()
153-
lifespan = bind_scope(self.name, self)
154-
self.enter(lifespan)
199+
self.delegate.__enter__()
155200
return self
156201

157202
def __exit__(
@@ -160,10 +205,10 @@ def __exit__(
160205
exc_value: BaseException | None,
161206
traceback: TracebackType | None,
162207
) -> Any:
163-
return self.__exit_stack.__exit__(exc_type, exc_value, traceback)
208+
return self.delegate.__exit__(exc_type, exc_value, traceback)
164209

165210
async def aenter[T](self, context_manager: AsyncContextManager[T]) -> T:
166211
raise ScopeError("SyncScope doesn't support asynchronous context manager.")
167212

168213
def enter[T](self, context_manager: ContextManager[T]) -> T:
169-
return self.__exit_stack.enter_context(context_manager)
214+
return self.delegate.enter_context(context_manager)

tests/core/test_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from injection import Module, SyncScope
5+
from injection import Module, sync_scope
66
from injection.exceptions import (
77
ModuleError,
88
ModuleLockError,
@@ -346,7 +346,7 @@ class Dependency: ...
346346

347347
assert module.is_locked is False
348348

349-
with SyncScope("test"):
349+
with sync_scope("test"):
350350
instance_1 = module.get_instance(Dependency)
351351
assert module.is_locked is True
352352

0 commit comments

Comments
 (0)