Skip to content

Commit 29d61cb

Browse files
Natureshadown2ygk
authored andcommitted
OpenID: Fix get_additional_claims API
* always propagate request * have get_additional_claims return a dict again * allow get_additional_claims to return plain data instead of callables
1 parent 2a288fd commit 29d61cb

File tree

5 files changed

+57
-27
lines changed

5 files changed

+57
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
* #651 Batch expired token deletions in `cleartokens` management command
2121
* Added pt-BR translations.
2222
* #1070 Add a Celery task for clearing expired tokens, e.g. to be scheduled as a [periodic task](https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html)
23+
* #1069 OIDC: Re-introduce [additional claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#adding-claims-to-the-id-token) beyond `sub` to the id_token.
2324

2425
### Fixed
2526
* #1012 Return status for introspecting a nonexistent token from 401 to the correct value of 200 per [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662#section-2.2).

docs/oidc.rst

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,17 +245,22 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc),
245245
and the ``sub`` claim will use the primary key of the user as the value.
246246
You'll probably want to customize this and add additional claims or change
247247
what is sent for the ``sub`` claim. To do so, you will need to add a method to
248-
our custom validator.
248+
our custom validator. It should return a dictionary mapping a claim name to
249+
either the claim data, or a callable that will be called with the request to
250+
produce the claim data.
249251
Standard claim ``sub`` is included by default, to remove it override ``get_claim_list``::
250252
class CustomOAuth2Validator(OAuth2Validator):
251-
def get_additional_claims(self):
253+
def get_additional_claims(self, request):
252254
def get_user_email(request):
253-
return request.user.get_full_name()
255+
return request.user.get_user_email()
254256

257+
claims = {}
255258
# Element name, callback to obtain data
256-
claims_list = [ ("email", get_sub_cod),
257-
("username", get_user_email) ]
258-
return claims_list
259+
claims["email"] = get_user_email
260+
# Element name, plain data returned
261+
claims["username"] = request.user.get_full_name()
262+
263+
return claims
259264

260265
.. note::
261266
This ``request`` object is not a ``django.http.Request`` object, but an

oauth2_provider/oauth2_validators.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -728,24 +728,24 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs):
728728
def get_jwt_bearer_token(self, token, token_handler, request):
729729
return self.get_id_token(token, token_handler, request)
730730

731-
def get_claim_list(self):
732-
def get_sub_code(request):
733-
return str(request.user.id)
731+
def get_claim_dict(self, request):
732+
def get_sub_code(inner_request):
733+
return str(inner_request.user.id)
734734

735-
list = [ ("sub", get_sub_code) ]
735+
claims = {"sub": get_sub_code}
736736

737737
# https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
738-
add = self.get_additional_claims()
739-
list.extend(add)
738+
add = self.get_additional_claims(request)
739+
claims.update(add)
740740

741-
return list
741+
return claims
742742

743743
def get_oidc_claims(self, token, token_handler, request):
744-
data = self.get_claim_list()
744+
data = self.get_claim_dict(request)
745745
claims = {}
746746

747-
for k, call in data:
748-
claims[k] = call(request)
747+
for k, v in data.items():
748+
claims[k] = v(request) if callable(v) else v
749749
return claims
750750

751751
def get_id_token_dictionary(self, token, token_handler, request):
@@ -898,5 +898,5 @@ def get_userinfo_claims(self, request):
898898
"""
899899
return self.get_oidc_claims(None, None, request)
900900

901-
def get_additional_claims(self):
902-
return []
901+
def get_additional_claims(self, request):
902+
return {}

oauth2_provider/views/oidc.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def get(self, request, *args, **kwargs):
4848

4949
validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
5050
validator = validator_class()
51-
oidc_claims = []
52-
for el, _ in validator.get_claim_list():
53-
oidc_claims.append(el)
51+
oidc_claims = list(validator.get_claim_dict(request).keys())
5452

5553
data = {
5654
"issuer": issuer_url,

tests/test_oidc_views.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,39 @@ def claim_user_email(request):
156156

157157

158158
@pytest.mark.django_db
159-
def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings):
159+
def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings):
160160
class CustomValidator(OAuth2Validator):
161-
def get_additional_claims(self):
162-
return [
163-
("username", claim_user_email),
164-
("email", claim_user_email),
165-
]
161+
def get_additional_claims(self, request):
162+
return {
163+
"username": claim_user_email,
164+
"email": claim_user_email,
165+
}
166+
167+
oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator
168+
auth_header = "Bearer %s" % oidc_tokens.access_token
169+
rsp = client.get(
170+
reverse("oauth2_provider:user-info"),
171+
HTTP_AUTHORIZATION=auth_header,
172+
)
173+
data = rsp.json()
174+
assert "sub" in data
175+
assert data["sub"] == str(oidc_tokens.user.pk)
176+
177+
assert "username" in data
178+
assert data["username"] == EXAMPLE_EMAIL
179+
180+
assert "email" in data
181+
assert data["email"] == EXAMPLE_EMAIL
182+
183+
184+
@pytest.mark.django_db
185+
def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings):
186+
class CustomValidator(OAuth2Validator):
187+
def get_additional_claims(self, request):
188+
return {
189+
"username": EXAMPLE_EMAIL,
190+
"email": EXAMPLE_EMAIL,
191+
}
166192

167193
oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator
168194
auth_header = "Bearer %s" % oidc_tokens.access_token

0 commit comments

Comments
 (0)