Skip to content

Commit 1e07dc9

Browse files
committed
refactor: nicer way of implementing granular optional auth
1 parent bace82a commit 1e07dc9

File tree

3 files changed

+116
-65
lines changed

3 files changed

+116
-65
lines changed

example/main.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
app = FastAPI(
2323
title="Example",
2424
version="dev",
25-
dependencies=[Depends(auth.oidc_scheme)],
25+
dependencies=[Depends(auth.implicit_scheme)],
2626
)
2727

2828
# CORS errors instead of seeing internal exceptions
@@ -42,7 +42,7 @@ def redirect_to_docs():
4242

4343

4444
@app.get("/protected")
45-
def protected(id_token: Optional[KeycloakIDToken] = Security(auth.authenticate_user())):
45+
def protected(id_token: Optional[KeycloakIDToken] = Security(auth.required)):
4646
print(id_token)
4747
if id_token is None:
4848
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -52,9 +52,7 @@ def protected(id_token: Optional[KeycloakIDToken] = Security(auth.authenticate_u
5252

5353
@app.get("/mixed")
5454
def mixed(
55-
id_token: Optional[KeycloakIDToken] = Security(
56-
auth.authenticate_user(auto_error=False)
57-
),
55+
id_token: Optional[KeycloakIDToken] = Security(auth.optional),
5856
):
5957
if id_token is None:
6058
return dict(message="You are not authenticated")

fastapi_oidc/auth.py

Lines changed: 111 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(
5757
issuer: Optional[str] = None,
5858
audience: Optional[str] = None,
5959
scopes: Dict[str, str] = dict(),
60-
auto_error: bool = True,
6160
signature_cache_ttl: int = 3600,
6261
idtoken_model: Type = IDToken,
6362
):
@@ -70,7 +69,6 @@ def __init__(
7069
issuer (URL): (Optional) The issuer URL from your auth server.
7170
audience (str): (Optional) The audience string configured by your auth server.
7271
scopes (Dict[str, str]): (Optional) A dictionary of scopes and their descriptions.
73-
auto_error (bool): (Optional) If True, raise an HTTPException if the token is invalid.
7472
signature_cache_ttl (int): How many seconds your app should cache the
7573
authorization server's public signatures.
7674
idtoken_model (Type): (Optional) The model to use for validating the ID Token.
@@ -82,7 +80,6 @@ def __init__(
8280
self.openid_connect_url = openid_connect_url
8381
self.issuer = issuer
8482
self.audience = audience
85-
self.auto_error = auto_error
8683
self.idtoken_model = idtoken_model
8784

8885
self.discover = discovery.configure(cache_ttl=signature_cache_ttl)
@@ -119,74 +116,129 @@ def __init__(
119116
auto_error=False,
120117
)
121118

122-
def authenticate_user(self, auto_error=None):
119+
def required(
120+
self,
121+
security_scopes: SecurityScopes,
122+
authorization_credentials: Optional[HTTPAuthorizationCredentials] = Depends(
123+
HTTPBearer()
124+
),
125+
) -> Optional[IDToken]:
126+
"""Validate and parse OIDC ID token against issuer in config.
127+
Note this function caches the signatures and algorithms of the issuing
128+
server for signature_cache_ttl seconds.
129+
130+
Args:
131+
security_scopes (SecurityScopes): Security scopes
132+
auth_header (str): Base64 encoded OIDC Token. This is invoked
133+
behind the scenes by Depends.
134+
135+
Return:
136+
IDToken: Dictionary with IDToken information
137+
138+
raises:
139+
HTTPException(status_code=401, detail=f"Unauthorized: {err}")
140+
IDToken validation errors
141+
"""
142+
143+
return self.authenticate_user(
144+
security_scopes,
145+
authorization_credentials,
146+
auto_error=True,
147+
)
148+
149+
def optional(
150+
self,
151+
security_scopes: SecurityScopes,
152+
authorization_credentials: Optional[HTTPAuthorizationCredentials] = Depends(
153+
HTTPBearer(auto_error=False)
154+
),
155+
) -> Optional[IDToken]:
156+
"""Optionally validate and parse OIDC ID token against issuer in config.
157+
Will not raise if the user is not authenticated. Note this function
158+
caches the signatures and algorithms of the issuing server for
159+
signature_cache_ttl seconds.
160+
161+
Args:
162+
security_scopes (SecurityScopes): Security scopes
163+
auth_header (str): Base64 encoded OIDC Token. This is invoked
164+
behind the scenes by Depends.
165+
166+
Return:
167+
IDToken: Dictionary with IDToken information
168+
169+
raises:
170+
IDToken validation errors
171+
"""
172+
173+
return self.authenticate_user(
174+
security_scopes,
175+
authorization_credentials,
176+
auto_error=False,
177+
)
178+
179+
def authenticate_user(
180+
self,
181+
security_scopes: SecurityScopes,
182+
authorization_credentials: Optional[HTTPAuthorizationCredentials],
183+
auto_error: bool,
184+
) -> Optional[IDToken]:
123185
"""Validate and parse OIDC ID token against issuer in config.
124186
Note this function caches the signatures and algorithms of the issuing server
125187
for signature_cache_ttl seconds.
126188
127189
Args:
128-
auth_header (str): Base64 encoded OIDC Token. This is invoked behind the
129-
scenes by Depends.
190+
security_scopes (SecurityScopes): Security scopes
191+
auth_header (str): Base64 encoded OIDC Token
192+
auto_error (bool): If True, will raise an HTTPException if the user
193+
is not authenticated.
130194
131195
Return:
132196
IDToken: Dictionary with IDToken information
133197
134198
raises:
135199
HTTPException(status_code=401, detail=f"Unauthorized: {err}")
136200
"""
201+
if authorization_credentials is None:
202+
if auto_error:
203+
raise HTTPException(
204+
status.HTTP_401_UNAUTHORIZED, detail="Missing bearer token"
205+
)
206+
else:
207+
return None
137208

138-
if auto_error is None:
139-
auto_error = self.auto_error
140-
141-
def authenticate_user_(
142-
security_scopes: SecurityScopes,
143-
authorization_credentials: Optional[HTTPAuthorizationCredentials] = Depends(
144-
HTTPBearer(auto_error=auto_error)
145-
),
146-
) -> Optional[IDToken]:
147-
if authorization_credentials is None:
148-
if auto_error:
149-
raise HTTPException(
150-
status.HTTP_401_UNAUTHORIZED, detail="Missing bearer token"
151-
)
152-
else:
153-
return None
154-
155-
oidc_discoveries = self.discover.auth_server(
156-
openid_connect_url=self.openid_connect_url
209+
oidc_discoveries = self.discover.auth_server(
210+
openid_connect_url=self.openid_connect_url
211+
)
212+
key = self.discover.public_keys(oidc_discoveries)
213+
algorithms = self.discover.signing_algos(oidc_discoveries)
214+
215+
try:
216+
id_token = jwt.decode(
217+
authorization_credentials.credentials,
218+
key,
219+
algorithms,
220+
audience=self.audience,
221+
issuer=self.issuer,
222+
options={
223+
# Disabled at_hash check since we aren't using the access token
224+
"verify_at_hash": False,
225+
"verify_iss": self.issuer is not None,
226+
"verify_aud": self.audience is not None,
227+
},
157228
)
158-
key = self.discover.public_keys(oidc_discoveries)
159-
algorithms = self.discover.signing_algos(oidc_discoveries)
160-
161-
try:
162-
id_token = jwt.decode(
163-
authorization_credentials.credentials,
164-
key,
165-
algorithms,
166-
audience=self.audience,
167-
issuer=self.issuer,
168-
options={
169-
# Disabled at_hash check since we aren't using the access token
170-
"verify_at_hash": False,
171-
"verify_iss": self.issuer is not None,
172-
"verify_aud": self.audience is not None,
173-
},
229+
except (ExpiredSignatureError, JWTError, JWTClaimsError) as err:
230+
if auto_error:
231+
raise HTTPException(status_code=401, detail=f"Unauthorized: {err}")
232+
else:
233+
return None
234+
235+
if not set(security_scopes.scopes).issubset(id_token["scope"].split(" ")):
236+
if auto_error:
237+
raise HTTPException(
238+
status.HTTP_401_UNAUTHORIZED,
239+
detail=f"""Missing scope token, only have {id_token["scopes"]}""",
174240
)
175-
except (ExpiredSignatureError, JWTError, JWTClaimsError) as err:
176-
if auto_error:
177-
raise HTTPException(status_code=401, detail=f"Unauthorized: {err}")
178-
else:
179-
return None
180-
181-
if not set(security_scopes.scopes).issubset(id_token["scope"].split(" ")):
182-
if auto_error:
183-
raise HTTPException(
184-
status.HTTP_401_UNAUTHORIZED,
185-
detail=f"""Missing scope token, only have {id_token["scopes"]}""",
186-
)
187-
else:
188-
return None
189-
190-
return self.idtoken_model(**id_token)
191-
192-
return authenticate_user_
241+
else:
242+
return None
243+
244+
return self.idtoken_model(**id_token)

fastapi_oidc/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List
2+
from typing import Union
23

34
from pydantic import BaseModel
45
from pydantic import Extra
@@ -24,7 +25,7 @@ class IDToken(BaseModel):
2425

2526
iss: str
2627
sub: str
27-
aud: List[str]
28+
aud: Union[str, List[str]]
2829
exp: int
2930
iat: int
3031

0 commit comments

Comments
 (0)