Skip to content

Commit 1b700a5

Browse files
committed
Some tree_manager refactor and request scope contextvar fixes
1 parent 2c1d71a commit 1b700a5

File tree

8 files changed

+57
-68
lines changed

8 files changed

+57
-68
lines changed

ellar/di/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
from .constants import (
88
INJECTABLE_ATTRIBUTE,
99
MODULE_REF_TYPES,
10-
SCOPED_CONTEXT_VAR,
1110
AnnotationToValue,
11+
request_context_var,
12+
)
13+
from .injector import (
14+
Container,
15+
EllarInjector,
16+
ModuleTreeManager,
17+
register_request_scope_context,
1218
)
13-
from .injector import Container, EllarInjector, ModuleTreeManager
1419
from .scopes import (
1520
RequestORTransientScope,
1621
RequestScope,
@@ -45,12 +50,13 @@
4550
"has_binding",
4651
"get_scope",
4752
"RequestScopeContext",
48-
"SCOPED_CONTEXT_VAR",
53+
"request_context_var",
4954
"INJECTABLE_ATTRIBUTE",
5055
"AnnotationToValue",
5156
"MODULE_REF_TYPES",
5257
"InjectByTag",
5358
"ModuleTreeManager",
59+
"register_request_scope_context",
5460
]
5561

5662

ellar/di/constants.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
INJECTABLE_ATTRIBUTE = "__DI_SCOPE__"
77

88

9-
SCOPED_CONTEXT_VAR: contextvars.ContextVar[Optional[RequestScopeContext]] = (
10-
contextvars.ContextVar("SCOPED-CONTEXT-VAR")
9+
request_context_var: contextvars.ContextVar[Optional[RequestScopeContext]] = (
10+
contextvars.ContextVar("ellar.di.request_context_var")
1111
)
12-
SCOPED_CONTEXT_VAR.set(None)
12+
request_context_var.set(None)
1313
INJECTABLE_WATERMARK = "INJECTABLE_WATERMARK"
1414

1515

@@ -29,7 +29,7 @@ def __new__(mcls, name, bases, namespace):
2929
annotations.update(namespace.get("__annotations__", {}))
3030
cls.keys = []
3131
for k, v in annotations.items():
32-
if type(v) == type(str):
32+
if type(v) is type(str):
3333
value = str(k).lower()
3434
setattr(cls, k, value)
3535
cls.keys.append(value)

ellar/di/context.py

Whitespace-only changes.

ellar/di/injector/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from .container import Container
2-
from .ellar_injector import EllarInjector
2+
from .ellar_injector import EllarInjector, register_request_scope_context
33
from .tree_manager import ModuleTreeManager
44

5-
__all__ = ["Container", "EllarInjector", "ModuleTreeManager"]
5+
__all__ = [
6+
"Container",
7+
"EllarInjector",
8+
"ModuleTreeManager",
9+
"register_request_scope_context",
10+
]

ellar/di/injector/container.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,6 @@ def get_binding(self, interface: t.Type) -> t.Tuple[Binding, InjectorBinder]:
167167
return super().get_binding(interface)
168168
except (KeyError, UnsatisfiedRequirement) as uex:
169169
try:
170-
# All services or providers must have MODULE_SCOPE_OWNER if added in @Module 'providers',
171-
# 'controllers' and 'routers'.
172-
# So if 'module_scope_owner' is None, it means `Interface` was not registered to any module
173170
if self.injector.owner:
174171
module_name = (
175172
self.injector.owner.name if self.injector.owner else None
@@ -184,23 +181,6 @@ def get_binding(self, interface: t.Type) -> t.Tuple[Binding, InjectorBinder]:
184181
# TODO: possible circular import
185182
return module_owner.value.container._get_binding(interface)
186183

187-
# current_node = self.injector.tree_manager.find_module(
188-
# predicate=lambda n: n.parent == self.injector.module_name
189-
# )
190-
191-
# root_module_name = self.injector.tree_manager.get_root_module().name
192-
#
193-
# # module_data = self.injector.tree_manager.search_module_tree(
194-
# # filter_item=lambda data: data.name == module_name,
195-
# # find_predicate=lambda data: interface in data.exports,
196-
# # )
197-
# module_data = self.injector.tree_manager.search_module_tree(
198-
# filter_item=lambda data: data.name == module_name or data.name == root_module_name,
199-
# find_predicate=lambda data: interface in data.exports,
200-
# )
201-
# if module_data:
202-
# return module_data.value.container._get_binding(interface)
203-
204184
except (KeyError, UnsatisfiedRequirement, Exception) as ex:
205185
logging.exception(ex)
206186
raise UnsatisfiedRequirement(None, interface) from uex

ellar/di/injector/ellar_injector.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import sys
22
import typing as t
3-
from contextlib import asynccontextmanager
43
from functools import cached_property
54

6-
from ellar.di.constants import MODULE_REF_TYPES, SCOPED_CONTEXT_VAR, Tag
5+
from ellar.di.constants import MODULE_REF_TYPES, Tag, request_context_var
76
from ellar.di.injector.tree_manager import ModuleTreeManager
87
from ellar.di.logger import log
98
from injector import Injector, Scope, ScopeDecorator
109
from typing_extensions import Annotated
1110

12-
from ..asgi_args import RequestScopeContext
1311
from ..providers import InstanceProvider, Provider
1412
from ..types import T
1513
from .container import Container
@@ -22,6 +20,19 @@
2220
)
2321

2422

23+
def register_request_scope_context(interface: t.Type[T], value: T) -> None:
24+
# Sets RequestScope contexts so that they can be available when needed
25+
26+
scoped_context = request_context_var.get()
27+
if scoped_context is None:
28+
return
29+
30+
if isinstance(value, Provider):
31+
scoped_context.context.update({interface: value})
32+
else:
33+
scoped_context.context.update({interface: InstanceProvider(value)})
34+
35+
2536
class _TagInfo(t.NamedTuple):
2637
supertype: t.Type
2738
tag: str
@@ -124,23 +135,3 @@ def get(
124135
)
125136
log.debug(f"{self._log_prefix} -> {result}")
126137
return t.cast(T, result)
127-
128-
def update_scoped_context(self, interface: t.Type[T], value: T) -> None:
129-
# Sets RequestScope contexts so that they can be available when needed
130-
# TODO: Rename to `register_request_scope_context`
131-
scoped_context = SCOPED_CONTEXT_VAR.get()
132-
if scoped_context is None:
133-
return
134-
135-
if isinstance(value, Provider):
136-
scoped_context.context.update({interface: value})
137-
else:
138-
scoped_context.context.update({interface: InstanceProvider(value)})
139-
140-
@asynccontextmanager
141-
async def create_asgi_args(self) -> t.AsyncGenerator["EllarInjector", None]:
142-
try:
143-
SCOPED_CONTEXT_VAR.set(RequestScopeContext())
144-
yield self
145-
finally:
146-
SCOPED_CONTEXT_VAR.set(None)

ellar/di/injector/tree_manager.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def exports(self) -> t.List[t.Type]:
3030
return self.value.exports
3131

3232
@property
33-
def providers(self) -> t.Dict[t.Type, t.Type]:
34-
return self.value.providers
33+
def providers(self) -> t.Dict[t.Type, t.Union[t.Type, "ProviderConfig"]]:
34+
return self.value.providers # type:ignore[return-value]
3535

3636
@property
3737
def name(self) -> str:
@@ -45,7 +45,7 @@ def __str__(self) -> str:
4545

4646

4747
class ModuleTreeManager:
48-
__slots__ = ("modules", "_core_module", "_root_module")
48+
__slots__ = ("modules", "_core_module", "_app_module")
4949

5050
# , root_module: t.Union["ModuleRefBase", "ModuleSetup"]
5151
def __init__(
@@ -57,15 +57,16 @@ def __init__(
5757
) # Dictionary to store modules by their ID or value
5858

5959
self._core_module = app_core_module.module if app_core_module else None
60-
self._root_module: t.Optional[t.Type[t.Any]] = None
60+
self._app_module: t.Optional[t.Type[t.Any]] = None
6161

6262
if app_core_module:
6363
self.add_module(app_core_module.module, value=app_core_module)
6464

6565
@property
6666
def root_module(self) -> t.Type:
67-
assert self._root_module is not None, "RootModule is not ready"
68-
return self._root_module
67+
root_module = self._core_module or self._app_module
68+
assert root_module is not None, "RootModule is not ready"
69+
return root_module
6970

7071
def add_provider(
7172
self,
@@ -105,19 +106,19 @@ def add_module(
105106

106107
self.modules[parent_module].dependencies.append(module_type)
107108

108-
if parent_module == self._core_module and not self._root_module:
109-
self._root_module = data.value.module
109+
if parent_module == self._core_module and not self._app_module:
110+
self._app_module = data.value.module
110111

111-
elif not parent_module and not self._root_module and not self._core_module:
112-
self._root_module = module_type
112+
elif not parent_module and not self._app_module and not self._core_module:
113+
self._app_module = module_type
113114
elif (
114-
self._root_module
115+
self._app_module
115116
and parent_module
116117
and self._core_module
117118
and parent_module == self._core_module
118119
):
119120
raise Exception(
120-
f"EllarCoreModule can only have '{self._root_module}' as dependency"
121+
f"EllarCoreModule can only have '{self._app_module}' as dependency"
121122
)
122123
return self
123124

@@ -204,8 +205,9 @@ def find_module(
204205
if not found_any:
205206
yield None # type:ignore[misc]
206207

207-
def get_root_module(self) -> t.Union["ModuleRefBase", "ModuleSetup"]:
208-
data = self.get_module(self.root_module)
208+
def get_app_module(self) -> t.Union["ModuleRefBase", "ModuleSetup"]:
209+
assert self._app_module is not None, "AppModule is not ready"
210+
data = self.get_module(self._app_module)
209211
assert data
210212
return data.value
211213

ellar/di/scopes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import logging
12
import typing as t
23

3-
from ellar.di import SCOPED_CONTEXT_VAR, RequestScopeContext
4+
from ellar.di import RequestScopeContext, request_context_var
45
from injector import (
56
NoScope as TransientScope,
67
)
@@ -23,7 +24,11 @@
2324

2425
class RequestScope(InjectorScope):
2526
def get_context(self) -> t.Optional[RequestScopeContext]:
26-
return SCOPED_CONTEXT_VAR.get()
27+
try:
28+
return request_context_var.get()
29+
except Exception as ex:
30+
logging.exception(ex)
31+
return None
2732

2833
def get(self, key: t.Type[T], provider: Provider[T]) -> Provider[T]:
2934
scoped_context = self.get_context()

0 commit comments

Comments
 (0)