Skip to content

Commit d29c03b

Browse files
authored
feat: BI-6817 add dl_app_api_lib.AuthMiddleware (#1466)
1 parent ff64b6b commit d29c03b

File tree

20 files changed

+632
-12
lines changed

20 files changed

+632
-12
lines changed

lib/dl_app_api_base/dl_app_api_base/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@
55
HttpServerRequestContext,
66
HttpServerRequestContextDependencies,
77
HttpServerRequestContextManager,
8+
HttpServerSettings,
9+
)
10+
from .auth import (
11+
NoAuthChecker,
12+
NoAuthResult,
13+
OAuthChecker,
14+
OAuthCheckerSettings,
15+
OAuthResult,
16+
RequestAuthCheckerProtocol,
17+
RouteMatcher,
818
)
919
from .handlers import (
1020
BadRequestResponseSchema,
@@ -48,6 +58,7 @@
4858
"SubsystemReadinessAsyncCallback",
4959
"SubsystemReadinessCallback",
5060
"SubsystemReadinessSyncCallback",
61+
"HttpServerSettings",
5162
"HttpServerAppSettingsMixin",
5263
"HttpServerAppMixin",
5364
"HttpServerAppFactoryMixin",
@@ -71,4 +82,11 @@
7182
"HttpServerRequestContextDependencies",
7283
"HttpServerRequestContext",
7384
"HttpServerRequestContextManager",
85+
"RequestAuthCheckerProtocol",
86+
"OAuthChecker",
87+
"OAuthResult",
88+
"NoAuthChecker",
89+
"NoAuthResult",
90+
"RouteMatcher",
91+
"OAuthCheckerSettings",
7492
]

lib/dl_app_api_base/dl_app_api_base/app.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import (
23
Generic,
34
TypeVar,
@@ -9,6 +10,7 @@
910
import pydantic
1011
from typing_extensions import override
1112

13+
import dl_app_api_base.auth as auth
1214
import dl_app_api_base.handlers as handlers
1315
import dl_app_api_base.middlewares as middlewares
1416
import dl_app_api_base.openapi as openapi
@@ -38,13 +40,15 @@ class HttpServerAppMixin(dl_app_base.BaseApp):
3840

3941
@attr.define(frozen=True, kw_only=True)
4042
class HttpServerRequestContextDependencies(
43+
auth.AuthRequestContextDependenciesMixin,
4144
request_context.BaseRequestContextDependencies,
4245
):
4346
...
4447

4548

4649
class HttpServerRequestContext(
4750
request_id.RequestIdRequestContextMixin,
51+
auth.AuthRequestContextMixin,
4852
request_context.BaseRequestContext,
4953
):
5054
_dependencies: HttpServerRequestContextDependencies
@@ -103,9 +107,30 @@ async def _get_request_context_manager(
103107
) -> HttpServerRequestContextManager:
104108
return HttpServerRequestContextManager(
105109
context_factory=HttpServerRequestContext.factory,
106-
dependencies=HttpServerRequestContextDependencies(),
110+
dependencies=HttpServerRequestContextDependencies(
111+
request_auth_checkers=await self._get_request_auth_checkers(),
112+
),
107113
)
108114

115+
@dl_app_base.singleton_class_method_result
116+
async def _get_request_auth_checkers(
117+
self,
118+
) -> list[auth.RequestAuthCheckerProtocol]:
119+
return [
120+
auth.NoAuthChecker(
121+
route_matchers=[
122+
auth.RouteMatcher(
123+
path_regex=re.compile(r"^/api/v1/health/.*$"),
124+
methods=frozenset(["GET"]),
125+
),
126+
auth.RouteMatcher(
127+
path_regex=re.compile(r"^/api/v1/docs/.*$"),
128+
methods=frozenset(["GET"]),
129+
),
130+
],
131+
),
132+
]
133+
109134
@dl_app_base.singleton_class_method_result
110135
async def _get_aiohttp_app_middlewares(
111136
self,
@@ -119,11 +144,15 @@ async def _get_aiohttp_app_middlewares(
119144
request_context_provider=request_context_manager,
120145
)
121146
error_handling_middleware = middlewares.ErrorHandlingMiddleware()
147+
auth_middleware = auth.AuthMiddleware(
148+
request_context_provider=request_context_manager,
149+
)
122150

123151
return [
124152
request_context_middlewares.process,
125153
logging_middleware.process,
126154
error_handling_middleware.process,
155+
auth_middleware.process,
127156
]
128157

129158
async def _setup_routes(self, app: aiohttp.web.Application) -> None:
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from .checkers import (
2+
BaseRequestAuthChecker,
3+
BaseRequestAuthResult,
4+
NoAuthChecker,
5+
NoAuthResult,
6+
OAuthChecker,
7+
OAuthCheckerSettings,
8+
OAuthResult,
9+
RequestAuthCheckerProtocol,
10+
)
11+
from .exc import (
12+
AuthError,
13+
AuthFailureError,
14+
NoApplicableAuthCheckersError,
15+
)
16+
from .middleware import AuthMiddleware
17+
from .models import RouteMatcher
18+
from .request_context import (
19+
AuthRequestContextDependenciesMixin,
20+
AuthRequestContextMixin,
21+
)
22+
23+
24+
__all__ = [
25+
"BaseRequestAuthChecker",
26+
"RequestAuthCheckerProtocol",
27+
"BaseRequestAuthResult",
28+
"NoAuthChecker",
29+
"NoAuthResult",
30+
"OAuthChecker",
31+
"OAuthCheckerSettings",
32+
"OAuthResult",
33+
"AuthRequestContextDependenciesMixin",
34+
"AuthRequestContextMixin",
35+
"AuthError",
36+
"NoApplicableAuthCheckersError",
37+
"AuthFailureError",
38+
"AuthMiddleware",
39+
"RouteMatcher",
40+
]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from .base import (
2+
BaseRequestAuthChecker,
3+
BaseRequestAuthResult,
4+
RequestAuthCheckerProtocol,
5+
)
6+
from .no_auth import (
7+
NoAuthChecker,
8+
NoAuthResult,
9+
)
10+
from .oauth import (
11+
OAuthChecker,
12+
OAuthCheckerSettings,
13+
OAuthResult,
14+
)
15+
16+
17+
__all__ = [
18+
"BaseRequestAuthChecker",
19+
"RequestAuthCheckerProtocol",
20+
"BaseRequestAuthResult",
21+
"NoAuthChecker",
22+
"NoAuthResult",
23+
"OAuthChecker",
24+
"OAuthCheckerSettings",
25+
"OAuthResult",
26+
]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import abc
2+
import logging
3+
from typing import (
4+
Protocol,
5+
Sequence,
6+
)
7+
8+
import aiohttp.web
9+
import attr
10+
11+
import dl_app_api_base.auth.models as auth_models
12+
13+
14+
LOGGER = logging.getLogger(__name__)
15+
16+
17+
@attr.define(frozen=True, kw_only=True)
18+
class BaseRequestAuthResult:
19+
...
20+
21+
22+
class RequestAuthCheckerProtocol(Protocol):
23+
async def is_applicable(self, request: aiohttp.web.Request) -> bool:
24+
...
25+
26+
async def check(self, request: aiohttp.web.Request) -> BaseRequestAuthResult:
27+
"""
28+
:raises: AuthFailureError if the authentication fails
29+
"""
30+
...
31+
32+
33+
@attr.define(frozen=True, kw_only=True)
34+
class BaseRequestAuthChecker(abc.ABC):
35+
_route_matchers: Sequence[auth_models.RouteMatcher]
36+
37+
async def is_applicable(self, request: aiohttp.web.Request) -> bool:
38+
return any(route_matcher.matches(request.path, request.method) for route_matcher in self._route_matchers)
39+
40+
@abc.abstractmethod
41+
async def check(self, request: aiohttp.web.Request) -> BaseRequestAuthResult:
42+
...
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import aiohttp.web
2+
import attr
3+
4+
import dl_app_api_base.auth.checkers.base as auth_checkers_base
5+
6+
7+
class NoAuthResult(auth_checkers_base.BaseRequestAuthResult):
8+
...
9+
10+
11+
@attr.define(frozen=True, kw_only=True)
12+
class NoAuthChecker(auth_checkers_base.BaseRequestAuthChecker):
13+
async def check(self, request: aiohttp.web.Request) -> NoAuthResult:
14+
return NoAuthResult()
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Sequence
2+
3+
import aiohttp.web
4+
import attr
5+
import pydantic
6+
from typing_extensions import Self
7+
8+
import dl_app_api_base.auth.checkers.base as auth_checkers_base
9+
import dl_app_api_base.auth.exc as auth_exc
10+
import dl_app_api_base.auth.models as auth_models
11+
import dl_settings
12+
13+
14+
@attr.define(frozen=True, kw_only=True)
15+
class OAuthResult(auth_checkers_base.BaseRequestAuthResult):
16+
client_id: str
17+
18+
19+
class OAuthUserSettings(dl_settings.BaseSettings):
20+
CLIENT_ID: str
21+
TOKEN: str = pydantic.Field(repr=False, alias="token")
22+
23+
24+
class OAuthCheckerSettings(dl_settings.BaseSettings):
25+
USERS: dict[str, OAuthUserSettings]
26+
HEADER_KEY: str = "Authorization"
27+
HEADER_PREFIX: str = "Bearer "
28+
29+
30+
@attr.define(frozen=True, kw_only=True)
31+
class OAuthChecker(auth_checkers_base.BaseRequestAuthChecker):
32+
_token_to_result_map: dict[str, OAuthResult]
33+
_header_key: str
34+
_header_prefix: str
35+
36+
@classmethod
37+
def from_settings(
38+
cls,
39+
settings: OAuthCheckerSettings,
40+
route_matchers: Sequence[auth_models.RouteMatcher],
41+
) -> Self:
42+
return cls(
43+
route_matchers=route_matchers,
44+
token_to_result_map={user.TOKEN: OAuthResult(client_id=user.CLIENT_ID) for user in settings.USERS.values()},
45+
header_key=settings.HEADER_KEY,
46+
header_prefix=settings.HEADER_PREFIX,
47+
)
48+
49+
async def is_applicable(self, request: aiohttp.web.Request) -> bool:
50+
if not await super().is_applicable(request):
51+
return False
52+
53+
authorization_header = request.headers.get(self._header_key, None)
54+
if authorization_header is None:
55+
return False
56+
57+
return authorization_header.startswith(self._header_prefix)
58+
59+
async def check(self, request: aiohttp.web.Request) -> OAuthResult:
60+
authorization_header = request.headers.get(self._header_key, None)
61+
if authorization_header is None:
62+
raise auth_exc.AuthFailureError("Authorization header is required")
63+
64+
if not authorization_header.startswith(self._header_prefix):
65+
raise auth_exc.AuthFailureError("Invalid authorization header format")
66+
67+
token = authorization_header.removeprefix(self._header_prefix)
68+
user = self._token_to_result_map.get(token, None)
69+
if user is None:
70+
raise auth_exc.AuthFailureError("Invalid token")
71+
72+
return user
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class AuthError(Exception):
2+
...
3+
4+
5+
class NoApplicableAuthCheckersError(AuthError):
6+
...
7+
8+
9+
class AuthFailureError(AuthError):
10+
...
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import http
2+
import logging
3+
4+
import aiohttp.typedefs
5+
import aiohttp.web
6+
import attr
7+
8+
import dl_app_api_base.auth.exc as auth_exc
9+
import dl_app_api_base.auth.request_context as auth_request_context
10+
import dl_app_api_base.handlers as handlers
11+
import dl_app_api_base.request_context as request_context
12+
13+
14+
LOGGER = logging.getLogger(__name__)
15+
16+
17+
@attr.define(frozen=True, kw_only=True)
18+
class AuthMiddleware:
19+
_request_context_provider: request_context.RequestContextProviderProtocol[
20+
auth_request_context.AuthRequestContextMixin
21+
]
22+
23+
@aiohttp.web.middleware
24+
async def process(
25+
self,
26+
request: aiohttp.web.Request,
27+
handler: aiohttp.typedefs.Handler,
28+
) -> aiohttp.web.StreamResponse:
29+
context = self._request_context_provider.get()
30+
try:
31+
await context.get_auth_user()
32+
except auth_exc.AuthError:
33+
LOGGER.exception("Authentication failed")
34+
return handlers.Response.with_error(
35+
message="Unauthorized",
36+
code="UNAUTHORIZED",
37+
status=http.HTTPStatus.UNAUTHORIZED,
38+
)
39+
40+
return await handler(request)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import logging
2+
import re
3+
4+
import attr
5+
6+
7+
LOGGER = logging.getLogger(__name__)
8+
9+
_DEFAULT_ROUTE_MATCHER_METHODS = frozenset(["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
10+
11+
12+
@attr.define(frozen=True, kw_only=True)
13+
class RouteMatcher:
14+
path_regex: re.Pattern[str]
15+
methods: frozenset[str] = attr.ib(default=_DEFAULT_ROUTE_MATCHER_METHODS)
16+
17+
def matches(
18+
self,
19+
route: str,
20+
method: str,
21+
) -> bool:
22+
return self.path_regex.match(route) is not None and method in self.methods

0 commit comments

Comments
 (0)