Skip to content

Commit ea265aa

Browse files
Use IDToken Subclasses (HarryMWinters#18)
- Add token_type argument to get_config. - Adds tests for subclassing functionality. - Update docs for subclassing IDToken
1 parent 04838e3 commit ea265aa

File tree

4 files changed

+85
-33
lines changed

4 files changed

+85
-33
lines changed

README.md

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818

1919
---
2020

21-
:warning: **See [this issue](https://github.com/HarryMWinters/fastapi-oidc/issues/1) for simple role-your-own example of checking OIDC tokens.**
21+
:warning: **See [this issue](https://github.com/HarryMWinters/fastapi-oidc/issues/1) for
22+
simple role-your-own example of checking OIDC tokens.**
2223

23-
Verify and decrypt 3rd party OIDC ID tokens to protect your [fastapi](https://github.com/tiangolo/fastapi) endpoints.
24+
Verify and decrypt 3rd party OIDC ID tokens to protect your
25+
[fastapi](https://github.com/tiangolo/fastapi) endpoints.
2426

2527
**Documentation:** [ReadTheDocs](https://fastapi-oidc.readthedocs.io/en/latest/)
2628

@@ -30,10 +32,13 @@ Verify and decrypt 3rd party OIDC ID tokens to protect your [fastapi](https://gi
3032

3133
`pip install fastapi-oidc`
3234

35+
## Usage
36+
3337
### Verify ID Tokens Issued by Third Party
3438

3539
This is great if you just want to use something like Okta or google to handle
36-
your auth. All you need to do is verify the token and then you can extract user ID info from it.
40+
your auth. All you need to do is verify the token and then you can extract user ID info
41+
from it.
3742

3843
```python3
3944
from fastapi import Depends
@@ -60,3 +65,24 @@ app = FastAPI()
6065
def protected(id_token: IDToken = Depends(authenticate_user)):
6166
return {"Hello": "World", "user_email": id_token.email}
6267
```
68+
69+
#### Using your own tokens
70+
71+
The IDToken class will accept any number of extra field but if you want to craft your
72+
own token class and validation that's accounted for too.
73+
74+
```python3
75+
class CustomIDToken(fastapi_oidc.IDToken):
76+
custom_field: str
77+
custom_default: float = 3.14
78+
79+
80+
authenticate_user: Callable = get_auth(**OIDC_config, token_type=CustomIDToken)
81+
82+
app = FastAPI()
83+
84+
85+
@app.get("/protected")
86+
def protected(id_token: CustomIDToken = Depends(authenticate_user)):
87+
return {"Hello": "World", "user_email": id_token.custom_default}
88+
```

fastapi_oidc/auth.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_auth(authenticated_user: AuthenticatedUser = Depends(authenticate_user)
1818

1919
from typing import Callable
2020
from typing import Optional
21+
from typing import Type
2122

2223
from fastapi import Depends
2324
from fastapi import HTTPException
@@ -28,6 +29,7 @@ def test_auth(authenticated_user: AuthenticatedUser = Depends(authenticate_user)
2829
from jose.exceptions import JWTClaimsError
2930

3031
from fastapi_oidc import discovery
32+
from fastapi_oidc.exceptions import TokenSpecificationError
3133
from fastapi_oidc.types import IDToken
3234

3335

@@ -38,31 +40,43 @@ def get_auth(
3840
base_authorization_server_uri: str,
3941
issuer: str,
4042
signature_cache_ttl: int,
43+
token_type: Type[IDToken] = IDToken,
4144
) -> Callable[[str], IDToken]:
4245
"""Take configurations and return the authenticate_user function.
4346
4447
This function should only be invoked once at the beggining of your
4548
server code. The function it returns should be used to check user credentials.
4649
4750
Args:
48-
client_id (str): This string is provided when you register with your resource server.
51+
client_id (str): This string is provided when you register with your resource
52+
server.
53+
base_authorization_server_uri(URL): Everything before /.wellknow in your auth
54+
server URL. I.E. https://dev-123456.okta.com
55+
issuer (URL): Same as base_authorization. This is used to generating OpenAPI3.0
56+
docs which is broken (in OpenAPI/FastAPI) right now.
57+
signature_cache_ttl (int): How many seconds your app should cache the
58+
authorization server's public signatures.
4959
audience (str): (Optional) The audience string configured by your auth server.
5060
If not set defaults to client_id
51-
base_authorization_server_uri(URL): Everything before /.wellknow in your auth server URL.
52-
I.E. https://dev-123456.okta.com
53-
issuer (URL): Same as base_authorization. This is used to generating OpenAPI3.0 docs which
54-
is broken (in OpenAPI/FastAPI) right now.
55-
signature_cache_ttl (int): How many seconds your app should cache the authorization
56-
server's public signatures.
61+
token_type (IDToken or subclass): (Optional) An optional class to be returned by
62+
the authenticate_user function.
5763
5864
5965
Returns:
60-
func: authenticate_user(auth_header: str)
66+
func: authenticate_user(auth_header: str) -> IDToken (or token_type)
6167
6268
Raises:
6369
Nothing intentional
6470
"""
65-
# As far as I can tell this does two things.
71+
72+
if not issubclass(token_type, IDToken):
73+
raise TokenSpecificationError(
74+
"Invalid argument for token_type. "
75+
"Token type must be a subclass of fastapi_oidc.type.IDToken. "
76+
f"Received {token_type=}"
77+
)
78+
79+
# As far as I can tell the oauth2_scheme does two things.
6680
# 1. Extracts and returns the Authorization header.
6781
# 2. Integrates with the OpenAPI3.0 doc generation in FastAPI.
6882
# This integration doesn't matter much now since OpenAPI
@@ -79,8 +93,8 @@ def authenticate_user(auth_header: str = Depends(oauth2_scheme)) -> IDToken:
7993
for signature_cache_ttl seconds.
8094
8195
Args:
82-
auth_header (str): Base64 encoded OIDC Token. This is invoked behind the scenes
83-
by Depends.
96+
auth_header (str): Base64 encoded OIDC Token. This is invoked behind the
97+
scenes by Depends.
8498
8599
Return:
86100
IDToken (types.IDToken):
@@ -103,27 +117,9 @@ def authenticate_user(auth_header: str = Depends(oauth2_scheme)) -> IDToken:
103117
# Disabled at_hash check since we aren't using the access token
104118
options={"verify_at_hash": False},
105119
)
106-
return IDToken.parse_obj(token)
120+
return token_type.parse_obj(token)
107121

108122
except (ExpiredSignatureError, JWTError, JWTClaimsError) as err:
109123
raise HTTPException(status_code=401, detail=f"Unauthorized: {err}")
110124

111125
return authenticate_user
112-
113-
114-
# This is a dummy method for sphinx docs. DO NOT User.
115-
# TODO Find a way to doc higher order functions w/ sphinx.
116-
def authenticate_user(auth_header: str) -> IDToken: # type: ignore
117-
"""
118-
Validate and parse OIDC ID token against issuer in config.
119-
Note this function caches the signatures and algorithms of the issuing server
120-
for signature_cache_ttl seconds.
121-
122-
Args:
123-
auth_header (str): Base64 encoded OIDC Token. This is invoked behind the scenes
124-
by Depends.
125-
126-
Return:
127-
IDToken (types.IDToken):
128-
"""
129-
pass

fastapi_oidc/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class TokenSpecificationError(BaseException):
2+
pass

tests/test_auth.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import pytest
2+
13
from fastapi_oidc import auth
4+
from fastapi_oidc.exceptions import TokenSpecificationError
25
from fastapi_oidc.types import IDToken
36

47

@@ -40,3 +43,28 @@ def test__authenticate_user_no_aud(
4043

4144
assert id_token.email == test_email # nosec
4245
assert id_token.aud == no_audience_config["client_id"]
46+
47+
48+
def test__get_auth_raises_if_token_type_is_not_subclass_of_IDToken(no_audience_config):
49+
class BadToken:
50+
pass
51+
52+
with pytest.raises(TokenSpecificationError):
53+
auth.get_auth(**no_audience_config, token_type=BadToken)
54+
55+
56+
def test__authenticate_user_returns_custom_tokens(
57+
monkeypatch, mock_discovery, token_without_audience, no_audience_config
58+
):
59+
class CustomToken(IDToken):
60+
custom_field: str = "OnlySlightlyBent"
61+
62+
monkeypatch.setattr(auth.discovery, "configure", mock_discovery)
63+
64+
token = token_without_audience
65+
66+
authenticate_user = auth.get_auth(**no_audience_config, token_type=CustomToken)
67+
68+
custom_token: CustomToken = authenticate_user(auth_header=f"Bearer {token}")
69+
70+
assert custom_token.custom_field == "OnlySlightlyBent"

0 commit comments

Comments
 (0)