@@ -16,52 +16,114 @@ def test_auth(authenticated_user: AuthenticatedUser = Depends(authenticate_user)
16
16
return f"Hello {name}"
17
17
"""
18
18
19
- from typing import Callable
20
19
from typing import Dict
21
20
from typing import Optional
21
+ from typing import Type
22
22
23
23
from fastapi import Depends
24
24
from fastapi import HTTPException
25
+ from fastapi import Request
26
+ from fastapi import status
27
+ from fastapi .openapi .models import OAuthFlows
28
+ from fastapi .security import HTTPAuthorizationCredentials
29
+ from fastapi .security import HTTPBearer
30
+ from fastapi .security import OAuth2
31
+ from fastapi .security import OAuth2AuthorizationCodeBearer
32
+ from fastapi .security import OAuth2PasswordBearer
25
33
from fastapi .security import OpenIdConnect
34
+ from fastapi .security import SecurityScopes
26
35
from jose import ExpiredSignatureError
27
36
from jose import JWTError
28
37
from jose import jwt
29
38
from jose .exceptions import JWTClaimsError
30
39
31
40
from fastapi_oidc import discovery
41
+ from fastapi_oidc .types import IDToken
32
42
33
43
34
- def get_auth (
35
- openid_connect_url : str ,
36
- issuer : Optional [str ] = None ,
37
- audience : Optional [str ] = None ,
38
- signature_cache_ttl : int = 3600 ,
39
- ) -> Callable [[str ], Dict ]:
40
- """Take configurations and returns the :func:`authenticate_user` function.
44
+ class AuthBearer (HTTPBearer ):
45
+ async def __call__ (self , request : Request ):
46
+ return await super ().__call__ (request )
41
47
42
- This function should only be invoked once at the beggining of your
43
- server code. The function it returns should be used to check user credentials.
44
48
45
- Args:
46
- openid_connect_url (URL): URL to the "well known" openid connect config
47
- e.g. https://dev-123456.okta.com/.well-known/openid-configuration
48
- issuer (URL): (Optional) The issuer URL from your auth server.
49
- audience (str): (Optional) The audience string configured by your auth server.
50
- signature_cache_ttl (int): How many seconds your app should cache the
51
- authorization server's public signatures.
49
+ class EmptyOAuth2 (OAuth2 ):
50
+ async def __call__ (self , request : Request ) -> Optional [str ]:
51
+ return None
52
52
53
- Returns:
54
- func: authenticate_user(auth_header: str) -> Dict
55
53
56
- Raises:
57
- Nothing intentional
58
- """
54
+ class Auth :
55
+ def __init__ (
56
+ self ,
57
+ openid_connect_url : str ,
58
+ issuer : Optional [str ] = None ,
59
+ audience : Optional [str ] = None ,
60
+ scopes : Dict [str , str ] = dict (),
61
+ auto_error : bool = True ,
62
+ signature_cache_ttl : int = 3600 ,
63
+ idtoken_model : Type = IDToken ,
64
+ ):
65
+ """Configure authentication and use method :func:`authenticate_user`
66
+ to check user credentials.
59
67
60
- oauth2_scheme = OpenIdConnect (openIdConnectUrl = openid_connect_url )
61
-
62
- discover = discovery .configure (cache_ttl = signature_cache_ttl )
68
+ Args:
69
+ openid_connect_url (URL): URL to the "well known" openid connect config
70
+ e.g. https://dev-123456.okta.com/.well-known/openid-configuration
71
+ issuer (URL): (Optional) The issuer URL from your auth server.
72
+ audience (str): (Optional) The audience string configured by your auth server.
73
+ scopes (Dict[str, str]): (Optional) A dictionary of scopes and their descriptions.
74
+ auto_error (bool): (Optional) If True, raise an HTTPException if the token is invalid.
75
+ signature_cache_ttl (int): How many seconds your app should cache the
76
+ authorization server's public signatures.
77
+ idtoken_model (Type): (Optional) The model to use for validating the ID Token.
78
+
79
+ Raises:
80
+ Nothing intentional
81
+ """
63
82
64
- def authenticate_user (auth_header : str = Depends (oauth2_scheme )) -> Dict :
83
+ self .openid_connect_url = openid_connect_url
84
+ self .issuer = issuer
85
+ self .audience = audience
86
+ self .auto_error = auto_error
87
+ self .idtoken_model = idtoken_model
88
+
89
+ self .discover = discovery .configure (cache_ttl = signature_cache_ttl )
90
+ oidc_discoveries = self .discover .auth_server (
91
+ openid_connect_url = self .openid_connect_url
92
+ )
93
+
94
+ self .oidc_scheme = OpenIdConnect (
95
+ openIdConnectUrl = openid_connect_url , auto_error = auto_error
96
+ )
97
+ self .password_scheme = OAuth2PasswordBearer (
98
+ tokenUrl = self .discover .token_url (oidc_discoveries ),
99
+ scopes = scopes ,
100
+ )
101
+ self .implicit_scheme = EmptyOAuth2 (
102
+ flows = OAuthFlows (
103
+ implicit = {
104
+ "authorizationUrl" : self .discover .authorization_url (
105
+ oidc_discoveries
106
+ ),
107
+ "scopes" : scopes ,
108
+ }
109
+ ),
110
+ scheme_name = "OAuth2ImplicitBearer" ,
111
+ auto_error = auto_error ,
112
+ )
113
+ self .authcode_scheme = OAuth2AuthorizationCodeBearer (
114
+ authorizationUrl = self .discover .authorization_url (oidc_discoveries ),
115
+ tokenUrl = self .discover .token_url (oidc_discoveries ),
116
+ # refreshUrl=self.discover.refresh_url(oidc_discoveries),
117
+ scopes = scopes ,
118
+ )
119
+
120
+ def authenticate_user (
121
+ self ,
122
+ security_scopes : SecurityScopes ,
123
+ authorization_credentials : Optional [HTTPAuthorizationCredentials ] = Depends (
124
+ AuthBearer (auto_error = False )
125
+ ),
126
+ ) -> Optional [IDToken ]:
65
127
"""Validate and parse OIDC ID token against issuer in config.
66
128
Note this function caches the signatures and algorithms of the issuing server
67
129
for signature_cache_ttl seconds.
@@ -76,27 +138,48 @@ def authenticate_user(auth_header: str = Depends(oauth2_scheme)) -> Dict:
76
138
raises:
77
139
HTTPException(status_code=401, detail=f"Unauthorized: {err}")
78
140
"""
79
- id_token = auth_header .split (" " )[- 1 ]
80
- OIDC_discoveries = discover .auth_server (openid_connect_url = openid_connect_url )
81
- key = discover .public_keys (OIDC_discoveries )
82
- algorithms = discover .signing_algos (OIDC_discoveries )
141
+
142
+ if authorization_credentials is None :
143
+ if self .auto_error :
144
+ raise HTTPException (
145
+ status .HTTP_401_UNAUTHORIZED , detail = "Missing bearer token"
146
+ )
147
+ else :
148
+ return None
149
+
150
+ oidc_discoveries = self .discover .auth_server (
151
+ openid_connect_url = self .openid_connect_url
152
+ )
153
+ key = self .discover .public_keys (oidc_discoveries )
154
+ algorithms = self .discover .signing_algos (oidc_discoveries )
83
155
84
156
try :
85
- return jwt .decode (
86
- id_token ,
157
+ id_token = jwt .decode (
158
+ authorization_credentials . credentials ,
87
159
key ,
88
160
algorithms ,
89
- audience = audience ,
90
- issuer = issuer ,
161
+ audience = self . audience ,
162
+ issuer = self . issuer ,
91
163
options = {
92
164
# Disabled at_hash check since we aren't using the access token
93
165
"verify_at_hash" : False ,
94
- "verify_iss" : issuer is not None ,
95
- "verify_aud" : audience is not None ,
166
+ "verify_iss" : self . issuer is not None ,
167
+ "verify_aud" : self . audience is not None ,
96
168
},
97
169
)
98
-
99
170
except (ExpiredSignatureError , JWTError , JWTClaimsError ) as err :
100
- raise HTTPException (status_code = 401 , detail = f"Unauthorized: { err } " )
101
-
102
- return authenticate_user
171
+ if self .auto_error :
172
+ raise HTTPException (status_code = 401 , detail = f"Unauthorized: { err } " )
173
+ else :
174
+ return None
175
+
176
+ if not set (security_scopes .scopes ).issubset (id_token ["scope" ].split (" " )):
177
+ if self .auto_error :
178
+ raise HTTPException (
179
+ status .HTTP_401_UNAUTHORIZED ,
180
+ detail = f"""Missing scope token, only have { id_token ["scopes" ]} """ ,
181
+ )
182
+ else :
183
+ return None
184
+
185
+ return self .idtoken_model (** id_token )
0 commit comments