Skip to content

Commit 6519ff2

Browse files
committed
Merge branch 'oauth2' into id-token-decoder
2 parents 8935521 + 45bcb94 commit 6519ff2

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

msal/oauth2cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
__version__ = "0.1.0"
1+
__version__ = "0.2.0"
22

3-
from .oauth2 import Client
3+
from .oidc import Client
44
from .assertion import JwtSigner
55

msal/oauth2cli/oidc.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import json
2+
import base64
3+
import time
4+
5+
from . import oauth2
6+
7+
8+
def base64decode(raw):
9+
"""A helper can handle a padding-less raw input"""
10+
raw += '=' * (-len(raw) % 4) # https://stackoverflow.com/a/32517907/728675
11+
return base64.b64decode(raw).decode("utf-8")
12+
13+
14+
def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None):
15+
"""Decodes and validates an id_token and returns its claims as a dictionary.
16+
17+
ID token claims would at least contain: "iss", "sub", "aud", "exp", "iat",
18+
per `specs <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_
19+
and it may contain other optional content such as "preferred_username",
20+
`maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_
21+
"""
22+
decoded = json.loads(base64decode(id_token.split('.')[1]))
23+
err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
24+
if issuer and issuer != decoded["iss"]:
25+
# https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
26+
err = ('2. The Issuer Identifier for the OpenID Provider, "%s", '
27+
"(which is typically obtained during Discovery), "
28+
"MUST exactly match the value of the iss (issuer) Claim.") % issuer
29+
if client_id:
30+
valid_aud = client_id in decoded["aud"] if isinstance(
31+
decoded["aud"], list) else client_id == decoded["aud"]
32+
if not valid_aud:
33+
err = "3. The aud (audience) Claim must contain this client's client_id."
34+
# Per specs:
35+
# 6. If the ID Token is received via direct communication between
36+
# the Client and the Token Endpoint (which it is in this flow),
37+
# the TLS server validation MAY be used to validate the issuer
38+
# in place of checking the token signature.
39+
if (now or time.time()) > decoded["exp"]:
40+
err = "9. The current time MUST be before the time represented by the exp Claim."
41+
if nonce and nonce != decoded.get("nonce"):
42+
err = ("11. Nonce must be the same value "
43+
"as the one that was sent in the Authentication Request")
44+
if err:
45+
raise RuntimeError("%s id_token was: %s" % (
46+
err, json.dumps(decoded, indent=2)))
47+
return decoded
48+
49+
50+
class Client(oauth2.Client):
51+
"""OpenID Connect is a layer on top of the OAuth2.
52+
53+
See its specs at https://openid.net/connect/
54+
"""
55+
56+
def decode_id_token(self, id_token, nonce=None):
57+
"""See :func:`~decode_id_token`."""
58+
return decode_id_token(
59+
id_token, nonce=nonce,
60+
client_id=self.client_id, issuer=self.configuration.get("issuer"))
61+
62+
def _obtain_token(self, grant_type, *args, **kwargs):
63+
"""The result will also contain one more key "id_token_claims",
64+
whose value will be a dictionary returned by :func:`~decode_id_token`.
65+
"""
66+
ret = super(Client, self)._obtain_token(grant_type, *args, **kwargs)
67+
if "id_token" in ret:
68+
ret["id_token_claims"] = self.decode_id_token(ret["id_token"])
69+
return ret
70+

0 commit comments

Comments
 (0)