Skip to content

Commit d46e5b8

Browse files
committed
Merge branch 'id-token-decoder' into dev
2 parents 8935521 + 3afa44a commit d46e5b8

File tree

5 files changed

+89
-15
lines changed

5 files changed

+89
-15
lines changed

msal/oauth2cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
__version__ = "0.1.0"
1+
__version__ = "0.2.0"
22

3-
from .oauth2 import Client
3+
from .oidc import Client
44
from .assertion import JwtSigner
55

msal/oauth2cli/oidc.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import json
2+
import base64
3+
import time
4+
5+
from . import oauth2
6+
7+
8+
def base64decode(raw):
9+
"""A helper can handle a padding-less raw input"""
10+
raw += '=' * (-len(raw) % 4) # https://stackoverflow.com/a/32517907/728675
11+
return base64.b64decode(raw).decode("utf-8")
12+
13+
14+
def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None):
15+
"""Decodes and validates an id_token and returns its claims as a dictionary.
16+
17+
ID token claims would at least contain: "iss", "sub", "aud", "exp", "iat",
18+
per `specs <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_
19+
and it may contain other optional content such as "preferred_username",
20+
`maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_
21+
"""
22+
decoded = json.loads(base64decode(id_token.split('.')[1]))
23+
err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
24+
if issuer and issuer != decoded["iss"]:
25+
# https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
26+
err = ('2. The Issuer Identifier for the OpenID Provider, "%s", '
27+
"(which is typically obtained during Discovery), "
28+
"MUST exactly match the value of the iss (issuer) Claim.") % issuer
29+
if client_id:
30+
valid_aud = client_id in decoded["aud"] if isinstance(
31+
decoded["aud"], list) else client_id == decoded["aud"]
32+
if not valid_aud:
33+
err = "3. The aud (audience) Claim must contain this client's client_id."
34+
# Per specs:
35+
# 6. If the ID Token is received via direct communication between
36+
# the Client and the Token Endpoint (which it is in this flow),
37+
# the TLS server validation MAY be used to validate the issuer
38+
# in place of checking the token signature.
39+
if (now or time.time()) > decoded["exp"]:
40+
err = "9. The current time MUST be before the time represented by the exp Claim."
41+
if nonce and nonce != decoded.get("nonce"):
42+
err = ("11. Nonce must be the same value "
43+
"as the one that was sent in the Authentication Request")
44+
if err:
45+
raise RuntimeError("%s id_token was: %s" % (
46+
err, json.dumps(decoded, indent=2)))
47+
return decoded
48+
49+
50+
class Client(oauth2.Client):
51+
"""OpenID Connect is a layer on top of the OAuth2.
52+
53+
See its specs at https://openid.net/connect/
54+
"""
55+
56+
def decode_id_token(self, id_token, nonce=None):
57+
"""See :func:`~decode_id_token`."""
58+
return decode_id_token(
59+
id_token, nonce=nonce,
60+
client_id=self.client_id, issuer=self.configuration.get("issuer"))
61+
62+
def _obtain_token(self, grant_type, *args, **kwargs):
63+
"""The result will also contain one more key "id_token_claims",
64+
whose value will be a dictionary returned by :func:`~decode_id_token`.
65+
"""
66+
ret = super(Client, self)._obtain_token(grant_type, *args, **kwargs)
67+
if "id_token" in ret:
68+
ret["id_token_claims"] = self.decode_id_token(ret["id_token"])
69+
return ret
70+

msal/token_cache.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,16 @@
22
import threading
33
import time
44
import logging
5-
import base64
65

76
from .authority import canonicalize
7+
from .oauth2cli.oidc import base64decode, decode_id_token
88

99

1010
logger = logging.getLogger(__name__)
1111

1212
def is_subdict_of(small, big):
1313
return dict(big, **small) == big
1414

15-
def base64decode(raw): # This can handle a padding-less raw input
16-
raw += '=' * (-len(raw) % 4) # https://stackoverflow.com/a/32517907/728675
17-
return base64.b64decode(raw).decode("utf-8")
18-
1915

2016
class TokenCache(object):
2117
"""This is considered as a base class containing minimal cache behavior.
@@ -112,8 +108,8 @@ def add(self, event, now=None):
112108
}
113109

114110
if client_info:
115-
decoded_id_token = json.loads(
116-
base64decode(id_token.split('.')[1])) if id_token else {}
111+
decoded_id_token = decode_id_token(
112+
id_token, client_id=event["client_id"]) if id_token else {}
117113
key = self._build_account_key(home_account_id, environment, realm)
118114
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
119115
"home_account_id": home_account_id,

tests/test_application.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def setUp(self):
181181
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
182182
"response": TokenCacheTestCase.build_response(
183183
access_token="Siblings won't share AT. test_remove_account() will.",
184-
id_token=TokenCacheTestCase.build_id_token(),
184+
id_token=TokenCacheTestCase.build_id_token(aud=self.preexisting_family_app_id),
185185
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
186186
}) # The add(...) helper populates correct home_account_id for future searching
187187

tests/test_token_cache.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import base64
33
import json
4+
import time
45

56
from msal.token_cache import *
67
from tests import unittest
@@ -13,12 +14,17 @@
1314
class TokenCacheTestCase(unittest.TestCase):
1415

1516
@staticmethod
16-
def build_id_token(sub="sub", oid="oid", preferred_username="me", **kwargs):
17+
def build_id_token(
18+
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
19+
preferred_username="me", **claims):
1720
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
21+
"iss": iss,
1822
"sub": sub,
19-
"oid": oid,
23+
"aud": aud,
24+
"exp": exp or (time.time() + 100),
25+
"iat": iat or time.time(),
2026
"preferred_username": preferred_username,
21-
}, **kwargs)).encode()).decode('utf-8')
27+
}, **claims)).encode()).decode('utf-8')
2228

2329
@staticmethod
2430
def build_response( # simulate a response from AAD
@@ -54,9 +60,11 @@ def setUp(self):
5460
self.cache = TokenCache()
5561

5662
def testAdd(self):
57-
id_token = self.build_id_token(oid="object1234", preferred_username="John Doe")
63+
client_id = "my_client_id"
64+
id_token = self.build_id_token(
65+
oid="object1234", preferred_username="John Doe", aud=client_id)
5866
self.cache.add({
59-
"client_id": "my_client_id",
67+
"client_id": client_id,
6068
"scope": ["s2", "s1", "s3"], # Not in particular order
6169
"token_endpoint": "https://login.example.com/contoso/v2/token",
6270
"response": self.build_response(

0 commit comments

Comments
 (0)