|
| 1 | +import inspect |
1 | 2 | import typing as t |
2 | 3 |
|
3 | 4 | from ellar.common.interfaces import IEllarMiddleware |
4 | 5 | from ellar.common.types import ASGIApp |
5 | | -from ellar.core.context import current_injector |
6 | | -from ellar.di import injectable |
7 | | -from ellar.utils import build_init_kwargs |
| 6 | +from ellar.core.execution_context import current_injector |
| 7 | +from ellar.utils.importer import import_from_string |
| 8 | +from injector import _infer_injected_bindings |
8 | 9 | from starlette.middleware import Middleware |
9 | 10 |
|
10 | 11 | T = t.TypeVar("T") |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class EllarMiddleware(Middleware, IEllarMiddleware): |
14 | | - _provider_token: t.Optional[str] |
15 | | - |
16 | 15 | @t.no_type_check |
17 | 16 | def __init__( |
18 | 17 | self, |
19 | | - cls: t.Type[T], |
20 | | - provider_token: t.Optional[str] = None, |
| 18 | + cls_or_import_string: t.Union[t.Type[T], str], |
21 | 19 | **options: t.Any, |
22 | 20 | ) -> None: |
23 | | - super().__init__(cls, **options) |
24 | | - injectable()(self.cls) |
25 | | - self.kwargs = build_init_kwargs(self.cls, self.kwargs) |
26 | | - self._provider_token = provider_token |
27 | | - |
28 | | - def _register_middleware(self) -> None: |
29 | | - provider_token = self._provider_token |
30 | | - if provider_token: |
31 | | - module_data = next( |
32 | | - current_injector.tree_manager.find_module( |
33 | | - lambda data: data.name == provider_token |
34 | | - ) |
35 | | - ) |
36 | | - |
37 | | - if module_data and module_data.is_ready: |
38 | | - module_data.value.add_provider(self.cls, export=True) |
39 | | - return |
| 21 | + super().__init__(cls_or_import_string, **options) |
40 | 22 |
|
41 | | - current_injector.tree_manager.get_root_module().add_provider( |
42 | | - self.cls, export=True |
43 | | - ) |
| 23 | + def _ensure_class(self) -> None: |
| 24 | + if isinstance(self.cls, str): |
| 25 | + self.cls = import_from_string(self.cls) |
44 | 26 |
|
45 | 27 | def __iter__(self) -> t.Iterator[t.Any]: |
| 28 | + self._ensure_class() |
46 | 29 | as_tuple = (self, self.args, self.kwargs) |
47 | 30 | return iter(as_tuple) |
48 | 31 |
|
| 32 | + def create_object(self, **init_kwargs: t.Any) -> t.Any: |
| 33 | + _result = dict(init_kwargs) |
| 34 | + |
| 35 | + if hasattr(self.cls, "__init__"): |
| 36 | + spec = inspect.signature(self.cls.__init__) |
| 37 | + type_hints = _infer_injected_bindings( |
| 38 | + self.cls.__init__, only_explicit_bindings=False |
| 39 | + ) |
| 40 | + |
| 41 | + for k, annotation in type_hints.items(): |
| 42 | + parameter = spec.parameters.get(k) |
| 43 | + if k in _result or (parameter and parameter.default is None): |
| 44 | + continue |
| 45 | + |
| 46 | + _result[k] = current_injector.get(annotation) |
| 47 | + |
| 48 | + return self.cls(**_result) |
| 49 | + |
49 | 50 | @t.no_type_check |
50 | 51 | def __call__(self, app: ASGIApp, *args: t.Any, **kwargs: t.Any) -> T: |
51 | | - self._register_middleware() |
52 | | - kwargs.update(app=app) |
| 52 | + self._ensure_class() |
| 53 | + # kwargs.update(app=app) |
53 | 54 | try: |
54 | | - return current_injector.create_object(self.cls, additional_kwargs=kwargs) |
| 55 | + return self.create_object(**kwargs, app=app) |
55 | 56 | except TypeError: # pragma: no cover |
56 | 57 | # TODO: Fix future typing for lower python version. |
57 | | - return self.cls(*args, **kwargs) |
| 58 | + return self.cls(*args, **kwargs, app=app) |
0 commit comments