Skip to content

Commit 65d230e

Browse files
committed
Enabled all application middlewares to be configured through application config
1 parent 939fd1b commit 65d230e

File tree

10 files changed

+136
-71
lines changed

10 files changed

+136
-71
lines changed

ellar/auth/middleware/auth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ellar.common import AnonymousIdentity, IHostContextFactory
33
from ellar.common.types import TReceive, TScope, TSend
44
from ellar.core.conf import Config
5+
from ellar.core.middleware import Middleware as EllarMiddleware
56
from ellar.di import EllarInjector
67
from starlette.routing import compile_path
78
from starlette.types import ASGIApp
@@ -45,3 +46,7 @@ def is_static(self, scope: TScope) -> bool:
4546
Check is the request is for a static file
4647
"""
4748
return self._path_regex.match(scope["path"]) is not None
49+
50+
51+
# IdentityMiddleware Configuration
52+
identity_middleware = EllarMiddleware(IdentityMiddleware)

ellar/auth/middleware/session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ellar.auth.session import SessionServiceNullStrategy, SessionStrategy
22
from ellar.core.conf import Config
3+
from ellar.core.middleware import Middleware as EllarMiddleware
34
from starlette.datastructures import MutableHeaders
45
from starlette.requests import HTTPConnection
56
from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -9,7 +10,7 @@ class SessionMiddleware:
910
def __init__(
1011
self, app: ASGIApp, session_strategy: SessionStrategy, config: Config
1112
) -> None:
12-
config.setdefault("SESSION_DISABLED", False)
13+
config.set_defaults(SESSION_DISABLED=False)
1314
self.app = app
1415
self._session_strategy = session_strategy
1516
self._is_active = not isinstance(session_strategy, SessionServiceNullStrategy)
@@ -52,3 +53,7 @@ async def _send_wrapper(message: Message) -> None:
5253
await send(message)
5354

5455
await self.app(scope, receive, _send_wrapper)
56+
57+
58+
# SessionMiddleware Configuration
59+
session_middleware = EllarMiddleware(SessionMiddleware)

ellar/core/middleware/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware
21
from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware
32
from starlette.middleware.httpsredirect import (
43
HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
54
)
6-
from starlette.middleware.trustedhost import (
7-
TrustedHostMiddleware as TrustedHostMiddleware,
8-
)
95
from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware
106

7+
from .cors import CORSMiddleware
118
from .errors import ServerErrorMiddleware
129
from .exceptions import ExceptionMiddleware
1310
from .function import FunctionBasedMiddleware
1411
from .middleware import EllarMiddleware as Middleware
12+
from .trusted_host import TrustedHostMiddleware
1513
from .versioning import RequestVersioningMiddleware
1614

1715
__all__ = [

ellar/core/middleware/cors.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from ellar.common.types import ASGIApp
2+
from ellar.core.conf import Config
3+
from starlette.middleware.cors import CORSMiddleware as BaseCORSMiddleware
4+
5+
from .middleware import EllarMiddleware
6+
7+
8+
class CORSMiddleware(BaseCORSMiddleware):
9+
def __init__(self, app: ASGIApp, config: Config) -> None:
10+
super().__init__(
11+
app,
12+
allow_origins=config.CORS_ALLOW_ORIGINS,
13+
allow_credentials=config.CORS_ALLOW_CREDENTIALS,
14+
allow_methods=config.CORS_ALLOW_METHODS,
15+
allow_headers=config.CORS_ALLOW_HEADERS,
16+
allow_origin_regex=config.CORS_ALLOW_ORIGIN_REGEX,
17+
expose_headers=config.CORS_EXPOSE_HEADERS,
18+
max_age=config.CORS_MAX_AGE,
19+
)
20+
21+
22+
# CORSMiddleware Configuration
23+
cors_middleware = EllarMiddleware(CORSMiddleware)

ellar/core/middleware/errors.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
1-
import typing as t
2-
1+
from ellar.common import IExceptionMiddlewareService
32
from ellar.common.constants import SCOPE_RESPONSE_STARTED
4-
from ellar.common.interfaces import IExceptionHandler, IHostContextFactory
3+
from ellar.common.interfaces import IHostContextFactory
54
from ellar.common.responses import Response
65
from ellar.common.types import ASGIApp, TMessage, TReceive, TScope, TSend
7-
from ellar.core.connection.http import Request
6+
from ellar.core import Config
7+
from ellar.core.connection import Request
88
from ellar.di import EllarInjector
99
from starlette.middleware.errors import (
1010
ServerErrorMiddleware as BaseServerErrorMiddleware,
1111
)
1212
from starlette.requests import Request as StarletteRequest
1313
from starlette.responses import JSONResponse
1414

15+
from .middleware import EllarMiddleware
16+
1517

1618
class ServerErrorMiddleware(BaseServerErrorMiddleware):
1719
def __init__(
1820
self,
1921
app: ASGIApp,
20-
*,
21-
debug: bool,
22+
config: Config,
23+
exception_service: IExceptionMiddlewareService,
2224
injector: "EllarInjector",
23-
handler: t.Optional["IExceptionHandler"] = None,
2425
) -> None:
25-
_handler = None
26-
if handler:
27-
self._500_error_handler = handler
28-
_handler = self.error_handler
26+
self._500_error_handler = exception_service.get_500_error_handler()
2927

3028
super(ServerErrorMiddleware, self).__init__(
31-
debug=debug,
32-
handler=_handler, # type:ignore[arg-type]
29+
debug=config.DEBUG,
30+
handler=self.error_handler if self._500_error_handler else None,
3331
app=app,
3432
)
3533
self.injector = injector
@@ -69,3 +67,7 @@ def error_response(self, request: StarletteRequest, exc: Exception) -> Response:
6967
return JSONResponse(
7068
{"detail": "Internal server error", "status_code": 500}, status_code=500
7169
)
70+
71+
72+
# ServerErrorMiddleware Configuration
73+
server_error_middleware = EllarMiddleware(ServerErrorMiddleware)

ellar/core/middleware/exceptions.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
import typing as t
22

3+
from ellar.common import APIException, IExceptionMiddlewareService
34
from ellar.common.interfaces import IHostContextFactory
45
from ellar.common.types import ASGIApp, TMessage, TReceive, TScope, TSend
5-
from ellar.core.context import current_injector
6-
from ellar.core.exceptions import ExceptionMiddlewareService
6+
from ellar.core.conf import Config
7+
from ellar.core.exceptions.service import ExceptionMiddlewareService
8+
from ellar.core.execution_context import current_injector
79
from starlette.exceptions import HTTPException
810

9-
if t.TYPE_CHECKING: # pragma: no cover
10-
pass
11+
from .middleware import EllarMiddleware
1112

1213

1314
class ExceptionMiddleware:
1415
def __init__(
1516
self,
1617
app: ASGIApp,
17-
exception_middleware_service: ExceptionMiddlewareService,
18-
debug: bool = False,
18+
exception_middleware_service: IExceptionMiddlewareService,
19+
config: Config,
1920
) -> None:
2021
self.app = app
21-
self.debug = debug
22-
self._exception_middleware_service = exception_middleware_service
22+
self.debug = config.DEBUG
23+
self._exception_middleware_service: ExceptionMiddlewareService = t.cast(
24+
ExceptionMiddlewareService, exception_middleware_service
25+
)
2326

2427
async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
2528
if scope["type"] not in ("http", "websocket"):
@@ -40,7 +43,7 @@ async def sender(message: TMessage) -> None:
4043
except Exception as exc:
4144
handler = None
4245

43-
if isinstance(exc, HTTPException):
46+
if isinstance(exc, (HTTPException, APIException)):
4447
handler = self._exception_middleware_service.lookup_status_code_exception_handler(
4548
exc.status_code
4649
)
@@ -69,3 +72,7 @@ async def sender(message: TMessage) -> None:
6972
elif context.get_type() == "websocket":
7073
ws_context = context_factory.create_context(scope, receive, send)
7174
await handler.catch(ws_context, exc)
75+
76+
77+
# ExceptionMiddleware Configuration
78+
exception_middleware = EllarMiddleware(ExceptionMiddleware)

ellar/core/middleware/function.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import typing as t
22

3-
from ellar.common.interfaces import IHostContext, IHostContextFactory
3+
from ellar.common.interfaces import IHostContext
44
from ellar.common.types import ASGIApp, TReceive, TScope, TSend
5-
from ellar.core.connection import HTTPConnection
5+
from ellar.core.execution_context import current_connection
66
from starlette.responses import Response
77

88
AwaitableCallable = t.Callable[..., t.Awaitable]
@@ -57,18 +57,10 @@ async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
5757
await self.app(scope, receive, send)
5858
return
5959

60-
connection = HTTPConnection(scope, receive)
61-
62-
if not connection.service_provider: # pragma: no cover
63-
raise Exception("Service Provider is required")
64-
65-
context_factory = connection.service_provider.get(IHostContextFactory)
66-
context = context_factory.create_context(scope, receive, send)
67-
6860
async def call_next() -> None:
6961
await self.app(scope, receive, send)
7062

71-
response = await self.dispatch_function(context, call_next)
63+
response = await self.dispatch_function(current_connection, call_next)
7264

7365
if response and isinstance(response, Response):
7466
await response(scope, receive, send)
Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,58 @@
1+
import inspect
12
import typing as t
23

34
from ellar.common.interfaces import IEllarMiddleware
45
from ellar.common.types import ASGIApp
5-
from ellar.core.context import current_injector
6-
from ellar.di import injectable
7-
from ellar.utils import build_init_kwargs
6+
from ellar.core.execution_context import current_injector
7+
from ellar.utils.importer import import_from_string
8+
from injector import _infer_injected_bindings
89
from starlette.middleware import Middleware
910

1011
T = t.TypeVar("T")
1112

1213

1314
class EllarMiddleware(Middleware, IEllarMiddleware):
14-
_provider_token: t.Optional[str]
15-
1615
@t.no_type_check
1716
def __init__(
1817
self,
19-
cls: t.Type[T],
20-
provider_token: t.Optional[str] = None,
18+
cls_or_import_string: t.Union[t.Type[T], str],
2119
**options: t.Any,
2220
) -> None:
23-
super().__init__(cls, **options)
24-
injectable()(self.cls)
25-
self.kwargs = build_init_kwargs(self.cls, self.kwargs)
26-
self._provider_token = provider_token
27-
28-
def _register_middleware(self) -> None:
29-
provider_token = self._provider_token
30-
if provider_token:
31-
module_data = next(
32-
current_injector.tree_manager.find_module(
33-
lambda data: data.name == provider_token
34-
)
35-
)
36-
37-
if module_data and module_data.is_ready:
38-
module_data.value.add_provider(self.cls, export=True)
39-
return
21+
super().__init__(cls_or_import_string, **options)
4022

41-
current_injector.tree_manager.get_root_module().add_provider(
42-
self.cls, export=True
43-
)
23+
def _ensure_class(self) -> None:
24+
if isinstance(self.cls, str):
25+
self.cls = import_from_string(self.cls)
4426

4527
def __iter__(self) -> t.Iterator[t.Any]:
28+
self._ensure_class()
4629
as_tuple = (self, self.args, self.kwargs)
4730
return iter(as_tuple)
4831

32+
def create_object(self, **init_kwargs: t.Any) -> t.Any:
33+
_result = dict(init_kwargs)
34+
35+
if hasattr(self.cls, "__init__"):
36+
spec = inspect.signature(self.cls.__init__)
37+
type_hints = _infer_injected_bindings(
38+
self.cls.__init__, only_explicit_bindings=False
39+
)
40+
41+
for k, annotation in type_hints.items():
42+
parameter = spec.parameters.get(k)
43+
if k in _result or (parameter and parameter.default is None):
44+
continue
45+
46+
_result[k] = current_injector.get(annotation)
47+
48+
return self.cls(**_result)
49+
4950
@t.no_type_check
5051
def __call__(self, app: ASGIApp, *args: t.Any, **kwargs: t.Any) -> T:
51-
self._register_middleware()
52-
kwargs.update(app=app)
52+
self._ensure_class()
53+
# kwargs.update(app=app)
5354
try:
54-
return current_injector.create_object(self.cls, additional_kwargs=kwargs)
55+
return self.create_object(**kwargs, app=app)
5556
except TypeError: # pragma: no cover
5657
# TODO: Fix future typing for lower python version.
57-
return self.cls(*args, **kwargs)
58+
return self.cls(*args, **kwargs, app=app)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from ellar.common.types import ASGIApp
2+
from ellar.core.conf import Config
3+
from starlette.middleware.trustedhost import (
4+
TrustedHostMiddleware as BaseTrustedHostMiddleware,
5+
)
6+
7+
from .middleware import EllarMiddleware
8+
9+
10+
class TrustedHostMiddleware(BaseTrustedHostMiddleware):
11+
def __init__(self, app: ASGIApp, config: Config) -> None:
12+
self.config = config
13+
14+
allowed_hosts = config.ALLOWED_HOSTS
15+
16+
if config.DEBUG and allowed_hosts != ["*"]:
17+
allowed_hosts = ["*"]
18+
19+
super().__init__(
20+
app, allowed_hosts=allowed_hosts, www_redirect=self.config.REDIRECT_HOST
21+
)
22+
23+
24+
# TrustedHostMiddleware Configuration
25+
trusted_host_middleware = EllarMiddleware(TrustedHostMiddleware)

ellar/core/middleware/versioning.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,31 @@
66
from ellar.core.versioning import BaseAPIVersioning, DefaultAPIVersioning
77
from starlette.types import ASGIApp
88

9+
from .middleware import EllarMiddleware
10+
911

1012
class RequestVersioningMiddleware:
11-
def __init__(self, app: ASGIApp, *, debug: bool, config: "Config") -> None:
13+
def __init__(self, app: ASGIApp, config: "Config") -> None:
1214
self.app = app
13-
self.debug = debug
1415
self.config = config
1516

1617
async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
1718
if scope["type"] not in ["http", "websocket"]: # pragma: no cover
1819
await self.app(scope, receive, send)
1920
return
2021

22+
## setup Versioning Resolvers
2123
scheme = (
2224
t.cast(BaseAPIVersioning, self.config.VERSIONING_SCHEME)
2325
or DefaultAPIVersioning()
2426
)
2527

2628
version_scheme_resolver = scheme.get_version_resolver(scope)
2729
version_scheme_resolver.resolve()
30+
2831
scope[SCOPE_API_VERSIONING_RESOLVER] = version_scheme_resolver
2932
await self.app(scope, receive, send)
33+
34+
35+
# RequestVersioningMiddleware Configuration
36+
versioning_middleware = EllarMiddleware(RequestVersioningMiddleware)

0 commit comments

Comments
 (0)