Skip to content

Commit 1241140

Browse files
committed
GH-9: Add callback argument to middleware
1 parent 7606da6 commit 1241140

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

src/fastapi_oauth2/middleware.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from datetime import datetime
22
from datetime import timedelta
3+
from typing import Awaitable
4+
from typing import Callable
35
from typing import Dict
46
from typing import List
57
from typing import Optional
@@ -87,13 +89,20 @@ def identity(self) -> str:
8789

8890

8991
class OAuth2Backend(AuthenticationBackend):
90-
def __init__(self, config: OAuth2Config) -> None:
92+
"""Authentication backend for AuthenticationMiddleware."""
93+
94+
def __init__(
95+
self,
96+
config: OAuth2Config,
97+
callback: Callable[[User], Union[Awaitable[None], None]] = None,
98+
) -> None:
9199
Auth.set_http(config.allow_http)
92100
Auth.set_secret(config.jwt_secret)
93101
Auth.set_expires(config.jwt_expires)
94102
Auth.set_algorithm(config.jwt_algorithm)
95103
for client in config.clients:
96104
Auth.register_client(client)
105+
self.callback = callback
97106

98107
async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"]]:
99108
authorization = request.headers.get(
@@ -106,18 +115,39 @@ async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"]
106115
return Auth(), User()
107116

108117
user = Auth.jwt_decode(param)
109-
return Auth(user.pop("scope", [])), User(user)
118+
auth, user = Auth(user.pop("scope", [])), User(user)
119+
120+
# Call the callback function on authentication
121+
if callable(self.callback):
122+
coroutine = self.callback(user)
123+
if issubclass(type(coroutine), Awaitable):
124+
await coroutine
125+
return auth, user
110126

111127

112128
class OAuth2Middleware:
129+
"""Wrapper for the Starlette AuthenticationMiddleware."""
130+
113131
auth_middleware: AuthenticationMiddleware = None
114132

115-
def __init__(self, app: ASGIApp, config: Union[OAuth2Config, dict]) -> None:
133+
def __init__(
134+
self,
135+
app: ASGIApp,
136+
config: Union[OAuth2Config, dict],
137+
callback: Callable[[User], Union[Awaitable[None], None]] = None,
138+
**kwargs, # AuthenticationMiddleware kwargs
139+
) -> None:
140+
"""Initiates the middleware with the given configuration.
141+
142+
:param app: FastAPI application instance
143+
:param config: middleware configuration
144+
:param callback: callback function to be called after authentication
145+
"""
116146
if isinstance(config, dict):
117147
config = OAuth2Config(**config)
118148
elif not isinstance(config, OAuth2Config):
119149
raise TypeError("config is not a valid type")
120-
self.auth_middleware = AuthenticationMiddleware(app, OAuth2Backend(config))
150+
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
121151

122152
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
123153
await self.auth_middleware(scope, receive, send)

0 commit comments

Comments
 (0)