Skip to content

Commit 24424c0

Browse files
committed
Merge branch 'target-in-string' into dev
2 parents 1d1e1f5 + 9110eca commit 24424c0

File tree

6 files changed

+170
-33
lines changed

6 files changed

+170
-33
lines changed

msal/application.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,14 @@ def acquire_token_silent(
292292
})
293293
now = time.time()
294294
for entry in matches:
295-
if entry["expires_on"] - now < 5*60:
295+
expires_in = int(entry["expires_on"]) - now
296+
if expires_in < 5*60:
296297
continue # Removal is not necessary, it will be overwritten
298+
logger.debug("Cache hit an AT")
297299
return { # Mimic a real response
298300
"access_token": entry["secret"],
299301
"token_type": "Bearer",
300-
"expires_in": entry["expires_on"] - now,
302+
"expires_in": int(expires_in), # OAuth2 specs defines it as int
301303
}
302304

303305
matches = self.token_cache.find(
@@ -311,6 +313,7 @@ def acquire_token_silent(
311313
})
312314
client = self._build_client(self.client_credential, the_authority)
313315
for entry in matches:
316+
logger.debug("Cache hit an RT")
314317
response = client.obtain_token_by_refresh_token(
315318
entry, rt_getter=lambda token_item: token_item["secret"],
316319
scope=decorate_scope(scopes, self.client_id))

msal/token_cache.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,19 @@ def __init__(self):
3838
def find(self, credential_type, target=None, query=None):
3939
target = target or []
4040
assert isinstance(target, list), "Invalid parameter type"
41+
target_set = set(target)
4142
with self._lock:
43+
# Since the target inside token cache key is (per schema) unsorted,
44+
# there is no point to attempt an O(1) key-value search here.
45+
# So we always do an O(n) in-memory search.
4246
return [entry
4347
for entry in self._cache.get(credential_type, {}).values()
4448
if is_subdict_of(query or {}, entry)
45-
and set(target) <= set(entry.get("target", []))]
49+
and (target_set <= set(entry.get("target", "").split())
50+
if target else True)
51+
]
4652

47-
def add(self, event):
53+
def add(self, event, now=None):
4854
# type: (dict) -> None
4955
# event typically contains: client_id, scope, token_endpoint,
5056
# resposne, params, data, grant_type
@@ -56,9 +62,9 @@ def add(self, event):
5662
default=str, # A workaround when assertion is in bytes in Python 3
5763
))
5864
response = event.get("response", {})
59-
access_token = response.get("access_token", {})
60-
refresh_token = response.get("refresh_token", {})
61-
id_token = response.get("id_token", {})
65+
access_token = response.get("access_token")
66+
refresh_token = response.get("refresh_token")
67+
id_token = response.get("id_token")
6268
client_info = {}
6369
home_account_id = None
6470
if "client_info" in response:
@@ -67,6 +73,7 @@ def add(self, event):
6773
environment = realm = None
6874
if "token_endpoint" in event:
6975
_, environment, realm = canonicalize(event["token_endpoint"])
76+
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it
7077

7178
with self._lock:
7279

@@ -77,20 +84,22 @@ def add(self, event):
7784
self.CredentialType.ACCESS_TOKEN,
7885
event.get("client_id", ""),
7986
realm or "",
80-
' '.join(sorted(event.get("scope", []))),
87+
target,
8188
]).lower()
82-
now = time.time()
89+
now = time.time() if now is None else now
90+
expires_in = response.get("expires_in", 3599)
8391
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
8492
"credential_type": self.CredentialType.ACCESS_TOKEN,
8593
"secret": access_token,
8694
"home_account_id": home_account_id,
8795
"environment": environment,
8896
"client_id": event.get("client_id"),
89-
"target": event.get("scope"),
97+
"target": target,
9098
"realm": realm,
91-
"cached_at": now,
92-
"expires_on": now + response.get("expires_in", 3599),
93-
"extended_expires_on": now + response.get("ext_expires_in", 0),
99+
"cached_at": str(int(now)), # Schema defines it as a string
100+
"expires_on": str(int(now + expires_in)), # Same here
101+
"extended_expires_on": str(int( # Same here
102+
now + response.get("ext_expires_in", expires_in))),
94103
}
95104

96105
if client_info:
@@ -108,7 +117,10 @@ def add(self, event):
108117
"local_account_id": decoded_id_token.get(
109118
"oid", decoded_id_token.get("sub")),
110119
"username": decoded_id_token.get("preferred_username"),
111-
"authority_type": "AAD", # Always AAD?
120+
"authority_type":
121+
"ADFS" if realm == "adfs"
122+
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA
123+
# "client_info": response.get("client_info"), # Optional
112124
}
113125

114126
if id_token:
@@ -118,6 +130,7 @@ def add(self, event):
118130
self.CredentialType.ID_TOKEN,
119131
event.get("client_id", ""),
120132
realm or "",
133+
"" # Albeit irrelevant, schema requires an empty scope here
121134
]).lower()
122135
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
123136
"credential_type": self.CredentialType.ID_TOKEN,
@@ -132,16 +145,14 @@ def add(self, event):
132145
if refresh_token:
133146
key = self._build_rt_key(
134147
home_account_id, environment,
135-
event.get("client_id", ""), event.get("scope", []))
148+
event.get("client_id", ""), target)
136149
rt = {
137150
"credential_type": self.CredentialType.REFRESH_TOKEN,
138151
"secret": refresh_token,
139152
"home_account_id": home_account_id,
140153
"environment": environment,
141154
"client_id": event.get("client_id"),
142-
# Fields below are considered optional
143-
"target": event.get("scope"),
144-
"client_info": response.get("client_info"),
155+
"target": target, # Optional per schema though
145156
}
146157
if "foci" in response:
147158
rt["family_id"] = response["foci"]
@@ -158,7 +169,7 @@ def _build_rt_key(
158169
cls.CredentialType.REFRESH_TOKEN,
159170
client_id or "",
160171
"", # RT is cross-tenant in AAD
161-
' '.join(sorted(target or [])),
172+
target or "", # raw value could be None if deserialized from other SDK
162173
]).lower()
163174

164175
def remove_rt(self, rt_item):
@@ -169,7 +180,8 @@ def remove_rt(self, rt_item):
169180
def update_rt(self, rt_item, new_rt):
170181
key = self._build_rt_key(**rt_item)
171182
with self._lock:
172-
rt = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key]
183+
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
184+
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
173185
rt["secret"] = new_rt
174186

175187

@@ -195,8 +207,8 @@ class SerializableTokenCache(TokenCache):
195207
Indicates whether the cache state has changed since last
196208
:func:`~serialize` or :func:`~deserialize` call.
197209
"""
198-
def add(self, event):
199-
super(SerializableTokenCache, self).add(event)
210+
def add(self, event, **kwargs):
211+
super(SerializableTokenCache, self).add(event, **kwargs)
200212
self.has_state_changed = True
201213

202214
def remove_rt(self, rt_item):
@@ -219,5 +231,5 @@ def serialize(self):
219231
"""Serialize the current cache state into a string."""
220232
with self._lock:
221233
self.has_state_changed = False
222-
return json.dumps(self._cache)
234+
return json.dumps(self._cache, indent=4)
223235

sample/client_credential_sample.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
config["client_id"], authority=config["authority"],
3131
client_credential=config["secret"],
3232
# token_cache=... # Default cache is in memory only.
33-
# See SerializableTokenCache for more details.
33+
# You can learn how to use SerializableTokenCache from
34+
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
3435
)
3536

3637
# The pattern to acquire a token looks like this.
@@ -42,7 +43,7 @@
4243
result = app.acquire_token_silent(config["scope"], account=None)
4344

4445
if not result:
45-
# So no suitable token exists in cache. Let's get a new one from AAD.
46+
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
4647
result = app.acquire_token_for_client(scopes=config["scope"])
4748

4849
if "access_token" in result:

sample/device_flow_sample.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
{
55
"authority": "https://login.microsoftonline.com/organizations",
66
"client_id": "your_client_id",
7-
"scope": ["user.read"]
7+
"scope": ["User.Read"]
88
}
99
1010
You can then run this sample with a JSON configuration file:
@@ -28,7 +28,8 @@
2828
app = msal.PublicClientApplication(
2929
config["client_id"], authority=config["authority"],
3030
# token_cache=... # Default cache is in memory only.
31-
# See SerializableTokenCache for more details.
31+
# You can learn how to use SerializableTokenCache from
32+
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
3233
)
3334

3435
# The pattern to acquire a token looks like this.
@@ -39,7 +40,7 @@
3940
# We now check the cache to see if we have some end users signed in before.
4041
accounts = app.get_accounts()
4142
if accounts:
42-
# If so, you could then somehow display these accounts and let end user choose
43+
logging.info("Account(s) exists in cache, probably with token too. Let's try.")
4344
print("Pick the account you want to use to proceed:")
4445
for a in accounts:
4546
print(a["username"])
@@ -49,7 +50,7 @@
4950
result = app.acquire_token_silent(config["scope"], account=chosen)
5051

5152
if not result:
52-
# So no suitable token exists in cache. Let's get a new one from AAD.
53+
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
5354
flow = app.initiate_device_flow(scopes=config["scope"])
5455
print(flow["message"])
5556
# Ideally you should wait here, in order to save some unnecessary polling

sample/username_password_sample.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"authority": "https://login.microsoftonline.com/organizations",
66
"client_id": "your_client_id",
77
"username": "your_username@your_tenant.com",
8-
"scope": ["user.read"],
8+
"scope": ["User.Read"],
99
"password": "This is a sample only. You better NOT persist your password."
1010
}
1111
@@ -30,7 +30,8 @@
3030
app = msal.PublicClientApplication(
3131
config["client_id"], authority=config["authority"],
3232
# token_cache=... # Default cache is in memory only.
33-
# See SerializableTokenCache for more details.
33+
# You can learn how to use SerializableTokenCache from
34+
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
3435
)
3536

3637
# The pattern to acquire a token looks like this.
@@ -39,11 +40,11 @@
3940
# Firstly, check the cache to see if this end user has signed in before
4041
accounts = app.get_accounts(username=config["username"])
4142
if accounts:
42-
# It means the account(s) exists in cache, probably with token too. Let's try.
43+
logging.info("Account(s) exists in cache, probably with token too. Let's try.")
4344
result = app.acquire_token_silent(config["scope"], account=accounts[0])
4445

4546
if not result:
46-
# So no suitable token exists in cache. Let's get a new one from AAD.
47+
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
4748
result = app.acquire_token_by_username_password(
4849
config["username"], config["password"], scopes=config["scope"])
4950

tests/test_token_cache.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import logging
2+
import base64
3+
import json
4+
5+
from msal.token_cache import *
6+
from tests import unittest
7+
8+
9+
logger = logging.getLogger(__name__)
10+
logging.basicConfig(level=logging.DEBUG)
11+
12+
13+
class TokenCacheTestCase(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.cache = TokenCache()
17+
18+
def testAdd(self):
19+
client_info = base64.b64encode(b'''
20+
{"uid": "uid", "utid": "utid"}
21+
''').decode('utf-8')
22+
id_token = "header.%s.signature" % base64.b64encode(b'''{
23+
"sub": "subject",
24+
"oid": "object1234",
25+
"preferred_username": "John Doe"
26+
}''').decode('utf-8')
27+
self.cache.add({
28+
"client_id": "my_client_id",
29+
"scope": ["s2", "s1", "s3"], # Not in particular order
30+
"token_endpoint": "https://login.example.com/contoso/v2/token",
31+
"response": {
32+
"access_token": "an access token",
33+
"token_type": "some type",
34+
"expires_in": 3600,
35+
"refresh_token": "a refresh token",
36+
"client_info": client_info,
37+
"id_token": id_token,
38+
},
39+
}, now=1000)
40+
self.assertEqual(
41+
{
42+
'cached_at': "1000",
43+
'client_id': 'my_client_id',
44+
'credential_type': 'AccessToken',
45+
'environment': 'login.example.com',
46+
'expires_on': "4600",
47+
'extended_expires_on': "4600",
48+
'home_account_id': "uid.utid",
49+
'realm': 'contoso',
50+
'secret': 'an access token',
51+
'target': 's2 s1 s3',
52+
},
53+
self.cache._cache["AccessToken"].get(
54+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
55+
)
56+
self.assertEqual(
57+
{
58+
'client_id': 'my_client_id',
59+
'credential_type': 'RefreshToken',
60+
'environment': 'login.example.com',
61+
'home_account_id': "uid.utid",
62+
'secret': 'a refresh token',
63+
'target': 's2 s1 s3',
64+
},
65+
self.cache._cache["RefreshToken"].get(
66+
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
67+
)
68+
self.assertEqual(
69+
{
70+
'home_account_id': "uid.utid",
71+
'environment': 'login.example.com',
72+
'realm': 'contoso',
73+
'local_account_id': "object1234",
74+
'username': "John Doe",
75+
'authority_type': "MSSTS",
76+
},
77+
self.cache._cache["Account"].get('uid.utid-login.example.com-contoso')
78+
)
79+
self.assertEqual(
80+
{
81+
'credential_type': 'IdToken',
82+
'secret': id_token,
83+
'home_account_id': "uid.utid",
84+
'environment': 'login.example.com',
85+
'realm': 'contoso',
86+
'client_id': 'my_client_id',
87+
},
88+
self.cache._cache["IdToken"].get(
89+
'uid.utid-login.example.com-idtoken-my_client_id-contoso-')
90+
)
91+
92+
93+
class SerializableTokenCacheTestCase(TokenCacheTestCase):
94+
# Run all inherited test methods, and have extra check in tearDown()
95+
96+
def setUp(self):
97+
self.cache = SerializableTokenCache()
98+
self.cache.deserialize("""
99+
{
100+
"AccessToken": {
101+
"an-entry": {
102+
"foo": "bar"
103+
}
104+
},
105+
"customized": "whatever"
106+
}
107+
""")
108+
109+
def tearDown(self):
110+
state = self.cache.serialize()
111+
logger.debug("serialize() = %s", state)
112+
# Now assert all extended content are kept intact
113+
output = json.loads(state)
114+
self.assertEqual(output.get("customized"), "whatever",
115+
"Undefined cache keys and their values should be intact")
116+
self.assertEqual(
117+
output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"},
118+
"Undefined token keys and their values should be intact")
119+

0 commit comments

Comments
 (0)