1111import aiohttp_jinja2
1212from aiohttp import hdrs , web
1313
14- from ..authorization import Authorization
1514from ..const import CONF_TOKEN_LIFETIME
1615from .auth_client import AuthenticationClient
1716from .auth_database import AuthDatabase
@@ -36,11 +35,17 @@ def __init__(
3635 self .auth_database = auth_database
3736
3837 # Authorization Endpoint: obtain an authorization grant
39- self .app .router .add_get (path = "/oauth/authorize" , handler = self .authorization_endpoint_get )
40- self .app .router .add_post (path = "/oauth/authorize" , handler = self .authorization_endpoint_post )
38+ self .app .router .add_get (
39+ path = "/oauth/authorize" , handler = self .authorization_endpoint_get
40+ )
41+ self .app .router .add_post (
42+ path = "/oauth/authorize" , handler = self .authorization_endpoint_post
43+ )
4144
4245 # Token Endpoint: obtain an access token by authorization grant or refresh token
43- self .app .router .add_post (path = "/oauth/token" , handler = self .token_endpoint_handler )
46+ self .app .router .add_post (
47+ path = "/oauth/token" , handler = self .token_endpoint_handler
48+ )
4449
4550 self .app .router .add_post ("/revoke" , self .revoke_token_handler , name = "revoke" )
4651 self .app .router .add_get ("/protected" , self .protected_handler , name = "protected" )
@@ -73,7 +78,9 @@ async def protected_handler(self, request: web.Request) -> web.StreamResponse:
7378 return response
7479
7580 @aiohttp_jinja2 .template ("authorize.jinja2" )
76- async def authorization_endpoint_get (self , request : web .Request ) -> web .StreamResponse :
81+ async def authorization_endpoint_get (
82+ self , request : web .Request
83+ ) -> web .StreamResponse :
7784 """
7885 Validate the request to ensure that all required parameters are present and valid.
7986
@@ -105,7 +112,9 @@ async def authorization_endpoint_get(self, request: web.Request) -> web.StreamRe
105112
106113 # validate response_type
107114 if response_type != "code" :
108- _LOGGER .warning (f"The request is using an invalid response_type: { response_type } " )
115+ _LOGGER .warning (
116+ f"The request is using an invalid response_type: { response_type } "
117+ )
109118 data = """{
110119 "error":"unsupported_response_type",
111120 "error_description":"The request is using an invalid response_type"
@@ -120,7 +129,9 @@ async def authorization_endpoint_get(self, request: web.Request) -> web.StreamRe
120129 None ,
121130 )
122131 # validate if redirect_uri is in registered_auth_client
123- if not any (uri == redirect_uri for uri in registered_auth_client .redirect_uris ):
132+ if not any (
133+ uri == redirect_uri for uri in registered_auth_client .redirect_uris
134+ ):
124135 _LOGGER .error (f"redirect uri not found: { redirect_uri } " )
125136 data = """{
126137 "error":"unauthorized_client",
@@ -153,7 +164,9 @@ async def authorization_endpoint_get(self, request: web.Request) -> web.StreamRe
153164 # check if the requested scope is registered
154165 for requested_scope in requested_scopes :
155166 if requested_scope not in registered_scopes :
156- _LOGGER .error (f"The requested scope '{ requested_scope } ' is invalid, unknown, or malformed." )
167+ _LOGGER .error (
168+ f"The requested scope '{ requested_scope } ' is invalid, unknown, or malformed."
169+ )
157170 data = """{
158171 "error":"invalid_scope",
159172 "error_description":"The requested scope is invalid, unknown, or malformed."
@@ -227,7 +240,9 @@ async def authorization_endpoint_get(self, request: web.Request) -> web.StreamRe
227240 }"""
228241 return web .json_response (json .loads (data ))
229242
230- async def authorization_endpoint_post (self , request : web .Request ) -> web .StreamResponse :
243+ async def authorization_endpoint_post (
244+ self , request : web .Request
245+ ) -> web .StreamResponse :
231246 """
232247 Validate the resource owners credentials.
233248
@@ -252,7 +267,9 @@ async def authorization_endpoint_post(self, request: web.Request) -> web.StreamR
252267 if not any (client .client_id == client_id for client in self .auth_clients ):
253268 _LOGGER .warning (f"unknown client_id { client_id } " )
254269 if state is not None :
255- raise web .HTTPFound (f"{ redirect_uri } ?error=unauthorized_client&state={ state } " )
270+ raise web .HTTPFound (
271+ f"{ redirect_uri } ?error=unauthorized_client&state={ state } "
272+ )
256273 else :
257274 raise web .HTTPFound (f"{ redirect_uri } ?error=unauthorized_client" )
258275
@@ -265,41 +282,57 @@ async def authorization_endpoint_post(self, request: web.Request) -> web.StreamR
265282 if not any (uri == redirect_uri for uri in registered_auth_client .redirect_uris ):
266283 _LOGGER .error (f"invalid redirect_uri { redirect_uri } " )
267284 if state is not None :
268- raise web .HTTPFound (f"{ redirect_uri } ?error=unauthorized_client&state={ state } " )
285+ raise web .HTTPFound (
286+ f"{ redirect_uri } ?error=unauthorized_client&state={ state } "
287+ )
269288 else :
270289 raise web .HTTPFound (f"{ redirect_uri } ?error=unauthorized_client" )
271290
272- username = data ["uname " ]
291+ email = data ["email " ]
273292 password = data ["password" ]
274293
275294 # validate credentials
276- credentials_are_valid = await self .auth_database .check_credentials (username , password )
295+ credentials_are_valid = await self .auth_database .check_credentials (
296+ email , password
297+ )
277298
278299 if credentials_are_valid :
279300 # create an authorization code
280- authorization_code = self .auth_database .create_authorization_code (username , client_id , request .remote )
301+ authorization_code = self .auth_database .create_authorization_code (
302+ email , client_id , request .remote
303+ )
281304 _LOGGER .debug (f"authorization_code: { authorization_code } " )
282305 if authorization_code is None :
283306 _LOGGER .warning ("could not create auth code for client!" )
284307 error_reason = "access_denied"
285308 if state is not None :
286- raise web .HTTPFound (f"{ redirect_uri } ?error={ error_reason } &state={ state } " )
309+ raise web .HTTPFound (
310+ f"{ redirect_uri } ?error={ error_reason } &state={ state } "
311+ )
287312 else :
288313 raise web .HTTPFound (f"{ redirect_uri } ?error={ error_reason } " )
289314
290315 if state is not None :
291- _LOGGER .debug (f"HTTPFound: { redirect_uri } ?code={ authorization_code } &state={ state } " )
292- redirect_response = web .HTTPFound (f"{ redirect_uri } ?code={ authorization_code } &state={ state } " )
316+ _LOGGER .debug (
317+ f"HTTPFound: { redirect_uri } ?code={ authorization_code } &state={ state } "
318+ )
319+ redirect_response = web .HTTPFound (
320+ f"{ redirect_uri } ?code={ authorization_code } &state={ state } "
321+ )
293322 else :
294323 _LOGGER .debug (f"HTTPFound: { redirect_uri } ?code={ authorization_code } " )
295- redirect_response = web .HTTPFound (f"{ redirect_uri } ?code={ authorization_code } " )
324+ redirect_response = web .HTTPFound (
325+ f"{ redirect_uri } ?code={ authorization_code } "
326+ )
296327
297328 raise redirect_response
298329 else :
299330 error_reason = "access_denied"
300331 _LOGGER .warning (f"redirect with error { error_reason } " )
301332 if state is not None :
302- raise web .HTTPFound (f"{ redirect_uri } ?error={ error_reason } &state={ state } " )
333+ raise web .HTTPFound (
334+ f"{ redirect_uri } ?error={ error_reason } &state={ state } "
335+ )
303336 else :
304337 raise web .HTTPFound (f"{ redirect_uri } ?error={ error_reason } " )
305338
@@ -358,13 +391,17 @@ async def _handle_authorization_code_request(self, data) -> web.StreamResponse:
358391 return web .json_response (status = 400 , data = data )
359392 client_id = data ["client_id" ]
360393
361- client_code_valid = await self .auth_database .validate_authorization_code (code , client_id )
394+ client_code_valid = await self .auth_database .validate_authorization_code (
395+ code , client_id
396+ )
362397 if not client_code_valid :
363398 _LOGGER .error ("authorization_code invalid!" )
364399 payload = {"error" : "invalid_grant" }
365400 return web .json_response (status = 400 , data = payload )
366401
367- access_token , refresh_token = await self .auth_database .create_tokens (code , client_id )
402+ access_token , refresh_token = await self .auth_database .create_tokens (
403+ code , client_id
404+ )
368405
369406 payload = {
370407 "access_token" : access_token ,
@@ -374,7 +411,9 @@ async def _handle_authorization_code_request(self, data) -> web.StreamResponse:
374411 }
375412 return web .json_response (status = 200 , data = payload )
376413
377- async def _handle_refresh_token_request (self , request : web .Request , data ) -> web .StreamResponse :
414+ async def _handle_refresh_token_request (
415+ self , request : web .Request , data
416+ ) -> web .StreamResponse :
378417 """
379418 See Section 6: https://tools.ietf.org/html/rfc6749#section-6
380419 """
@@ -414,7 +453,9 @@ async def _handle_refresh_token_request(self, request: web.Request, data) -> web
414453 data = {"error" : "invalid_client" }
415454 return web .json_response (data )
416455
417- access_token , refresh_token = await self .auth_database .renew_tokens (client_id , refresh_token )
456+ access_token , refresh_token = await self .auth_database .renew_tokens (
457+ client_id , refresh_token
458+ )
418459
419460 if access_token is None :
420461 raise web .HTTPForbidden ()
@@ -436,11 +477,13 @@ def create_client(self):
436477 _LOGGER .info (f"generated client_secret: { client_secret } " )
437478
438479 async def check_authorized (self , request : web .Request ) -> Optional [str ]:
439- """Check if authorization header and returns username if valid"""
480+ """Check if authorization header and returns user ID if valid"""
440481
441482 if hdrs .AUTHORIZATION in request .headers :
442483 try :
443- auth_type , auth_val = request .headers .get (hdrs .AUTHORIZATION ).split (" " , 1 )
484+ auth_type , auth_val = request .headers .get (hdrs .AUTHORIZATION ).split (
485+ " " , 1
486+ )
444487 if not await self .auth_database .validate_access_token (auth_val ):
445488 raise web .HTTPForbidden ()
446489
0 commit comments