1
+ from datetime import datetime
2
+ from datetime import timedelta
3
+ from typing import Dict
1
4
from typing import List
2
5
from typing import Optional
3
6
from typing import Tuple
4
7
from typing import Union
5
8
6
9
from fastapi .security .utils import get_authorization_scheme_param
10
+ from jose .jwt import decode as jwt_decode
11
+ from jose .jwt import encode as jwt_encode
7
12
from starlette .authentication import AuthenticationBackend
8
13
from starlette .middleware .authentication import AuthenticationMiddleware
9
14
from starlette .requests import Request
12
17
from starlette .types import Scope
13
18
from starlette .types import Send
14
19
20
+ from .client import OAuth2Client
15
21
from .config import OAuth2Config
16
- from .utils import jwt_decode
22
+ from .core import OAuth2Core
17
23
18
24
19
25
class Auth :
26
+ secret : str
27
+ expires : int
28
+ algorithm : str
20
29
scopes : List [str ]
30
+ clients : Dict [str , OAuth2Core ] = {}
21
31
22
32
def __init__ (self , scopes : Optional [List [str ]] = None ) -> None :
23
33
self .scopes = scopes or []
24
34
35
+ @classmethod
36
+ def set_secret (cls , secret : str ) -> None :
37
+ cls .secret = secret
38
+
39
+ @classmethod
40
+ def set_expires (cls , expires : int ) -> None :
41
+ cls .expires = expires
42
+
43
+ @classmethod
44
+ def set_algorithm (cls , algorithm : str ) -> None :
45
+ cls .algorithm = algorithm
46
+
47
+ @classmethod
48
+ def register_client (cls , client : OAuth2Client ) -> None :
49
+ cls .clients [client .backend .name ] = OAuth2Core (client )
50
+
51
+ @classmethod
52
+ def jwt_encode (cls , data : dict ) -> str :
53
+ return jwt_encode (data , cls .secret , algorithm = cls .algorithm )
54
+
55
+ @classmethod
56
+ def jwt_decode (cls , token : str ) -> dict :
57
+ return jwt_decode (token , cls .secret , algorithms = [cls .algorithm ])
58
+
59
+ @classmethod
60
+ def jwt_create (cls , token_data : dict ) -> str :
61
+ expire = datetime .utcnow () + timedelta (minutes = cls .expires )
62
+ return cls .jwt_encode ({** token_data , "exp" : expire })
63
+
25
64
26
65
class User (dict ):
27
66
is_authenticated : bool
@@ -32,30 +71,34 @@ def __init__(self, seq: Optional[dict] = None, **kwargs) -> None:
32
71
33
72
34
73
class OAuth2Backend (AuthenticationBackend ):
74
+ def __init__ (self , config : OAuth2Config ) -> None :
75
+ Auth .set_secret (config .jwt_secret )
76
+ Auth .set_expires (config .jwt_expires )
77
+ Auth .set_algorithm (config .jwt_algorithm )
78
+ OAuth2Core .allow_http = config .allow_http
79
+ for client in config .clients :
80
+ Auth .register_client (client )
81
+
35
82
async def authenticate (self , request : Request ) -> Optional [Tuple ["Auth" , "User" ]]:
36
83
authorization = request .cookies .get ("Authorization" )
37
84
scheme , param = get_authorization_scheme_param (authorization )
38
85
39
86
if not scheme or not param :
40
87
return Auth (), User ()
41
88
42
- user = jwt_decode (param )
43
- scopes = user .pop ("scope" )
44
- return Auth (scopes ), User (user )
89
+ user = Auth .jwt_decode (param )
90
+ return Auth (user .pop ("scope" )), User (user )
45
91
46
92
47
93
class OAuth2Middleware :
48
- config : OAuth2Config
49
- auth_middleware : AuthenticationMiddleware
94
+ auth_middleware : AuthenticationMiddleware = None
50
95
51
96
def __init__ (self , app : ASGIApp , config : Union [OAuth2Config , dict ]) -> None :
52
- if isinstance (config , OAuth2Config ):
53
- self .config = config
54
- elif isinstance (config , dict ):
55
- self .config = OAuth2Config (** config )
56
- else :
97
+ if isinstance (config , dict ):
98
+ config = OAuth2Config (** config )
99
+ elif not isinstance (config , OAuth2Config ):
57
100
raise TypeError ("config is not a valid type" )
58
- self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend ())
101
+ self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend (config ))
59
102
60
103
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
61
104
await self .auth_middleware (scope , receive , send )
0 commit comments