1
1
from datetime import datetime
2
2
from datetime import timedelta
3
+ from typing import Awaitable
4
+ from typing import Callable
3
5
from typing import Dict
4
6
from typing import List
5
7
from typing import Optional
@@ -87,13 +89,20 @@ def identity(self) -> str:
87
89
88
90
89
91
class OAuth2Backend (AuthenticationBackend ):
90
- def __init__ (self , config : OAuth2Config ) -> None :
92
+ """Authentication backend for AuthenticationMiddleware."""
93
+
94
+ def __init__ (
95
+ self ,
96
+ config : OAuth2Config ,
97
+ callback : Callable [[User ], Union [Awaitable [None ], None ]] = None ,
98
+ ) -> None :
91
99
Auth .set_http (config .allow_http )
92
100
Auth .set_secret (config .jwt_secret )
93
101
Auth .set_expires (config .jwt_expires )
94
102
Auth .set_algorithm (config .jwt_algorithm )
95
103
for client in config .clients :
96
104
Auth .register_client (client )
105
+ self .callback = callback
97
106
98
107
async def authenticate (self , request : Request ) -> Optional [Tuple ["Auth" , "User" ]]:
99
108
authorization = request .headers .get (
@@ -106,18 +115,39 @@ async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"]
106
115
return Auth (), User ()
107
116
108
117
user = Auth .jwt_decode (param )
109
- return Auth (user .pop ("scope" , [])), User (user )
118
+ auth , user = Auth (user .pop ("scope" , [])), User (user )
119
+
120
+ # Call the callback function on authentication
121
+ if callable (self .callback ):
122
+ coroutine = self .callback (user )
123
+ if issubclass (type (coroutine ), Awaitable ):
124
+ await coroutine
125
+ return auth , user
110
126
111
127
112
128
class OAuth2Middleware :
129
+ """Wrapper for the Starlette AuthenticationMiddleware."""
130
+
113
131
auth_middleware : AuthenticationMiddleware = None
114
132
115
- def __init__ (self , app : ASGIApp , config : Union [OAuth2Config , dict ]) -> None :
133
+ def __init__ (
134
+ self ,
135
+ app : ASGIApp ,
136
+ config : Union [OAuth2Config , dict ],
137
+ callback : Callable [[User ], Union [Awaitable [None ], None ]] = None ,
138
+ ** kwargs , # AuthenticationMiddleware kwargs
139
+ ) -> None :
140
+ """Initiates the middleware with the given configuration.
141
+
142
+ :param app: FastAPI application instance
143
+ :param config: middleware configuration
144
+ :param callback: callback function to be called after authentication
145
+ """
116
146
if isinstance (config , dict ):
117
147
config = OAuth2Config (** config )
118
148
elif not isinstance (config , OAuth2Config ):
119
149
raise TypeError ("config is not a valid type" )
120
- self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend (config ) )
150
+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
121
151
122
152
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
123
153
await self .auth_middleware (scope , receive , send )
0 commit comments