Skip to content

Commit 9ff7fac

Browse files
committed
Customizable token cache
Customizable data and response to be saved into token cache
1 parent 7db6c2c commit 9ff7fac

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

msal/application.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ class ClientApplication(object):
238238
"You can enable broker by following these instructions. "
239239
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")
240240

241+
_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache
242+
"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
243+
}
244+
241245
def __init__(
242246
self, client_id,
243247
client_credential=None, authority=None, validate_authority=True,
@@ -651,6 +655,7 @@ def __init__(
651655

652656
self._decide_broker(allow_broker, enable_pii_log)
653657
self.token_cache = token_cache or TokenCache()
658+
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
654659
self._region_configured = azure_region
655660
self._region_detected = None
656661
self.client, self._regional_client = self._build_client(
@@ -1528,9 +1533,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15281533
"realm": authority.tenant,
15291534
"home_account_id": (account or {}).get("home_account_id"),
15301535
}
1531-
key_id = kwargs.get("data", {}).get("key_id")
1532-
if key_id: # Some token types (SSH-certs, POP) are bound to a key
1533-
query["key_id"] = key_id
1536+
for field_in_data, field_in_cache in self._TOKEN_CACHE_DATA.items():
1537+
value = kwargs.get("data", {}).get(field_in_data)
1538+
if value:
1539+
query[field_in_cache] = value
15341540
now = time.time()
15351541
refresh_reason = msal.telemetry.AT_ABSENT
15361542
for entry in self.token_cache.search( # A generator allows us to

msal/token_cache.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import threading
33
import time
4+
from typing import Optional # Needed in Python 3.7 & 3.8
45
import logging
56
import warnings
67

@@ -39,6 +40,25 @@ class AuthorityType:
3940
ADFS = "ADFS"
4041
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA
4142

43+
_data_to_at: dict[str, str] = { # field_in_data: field_in_cache
44+
# Store extra data which we explicitly allow,
45+
# so that we won't accidentally store a user's password etc.
46+
# It can be used to store for example key_id used in SSH-cert or POP
47+
}
48+
_response_to_at: dict[str, str] = { # field_in_response: field_in_cache
49+
}
50+
51+
def _set(
52+
self,
53+
*,
54+
data_to_at: Optional[dict[str, str]] = None,
55+
response_to_at: Optional[dict[str, str]] = None,
56+
) -> None:
57+
# This helper should probably be better in __init__(),
58+
# but there is no easy way for MSAL EX to pick up a kwargs
59+
self._data_to_at = data_to_at or {}
60+
self._response_to_at = response_to_at or {}
61+
4262
def __init__(self):
4363
self._lock = threading.RLock()
4464
self._cache = {}
@@ -267,11 +287,14 @@ def __add(self, event, now=None):
267287
"expires_on": str(now + expires_in), # Same here
268288
"extended_expires_on": str(now + ext_expires_in) # Same here
269289
}
270-
at.update({k: data[k] for k in data if k in {
271-
# Also store extra data which we explicitly allow
272-
# So that we won't accidentally store a user's password etc.
273-
"key_id", # It happens in SSH-cert or POP scenario
274-
}})
290+
for field_in_resp, field_in_cache in self._response_to_at.items():
291+
value = response.get(field_in_resp)
292+
if value:
293+
at[field_in_cache] = value
294+
for field_in_data, field_in_cache in self._data_to_at.items():
295+
value = data.get(field_in_data)
296+
if value:
297+
at[field_in_cache] = value
275298
if "refresh_in" in response:
276299
refresh_in = response["refresh_in"] # It is an integer
277300
at["refresh_on"] = str(now + refresh_in) # Schema wants a string

tests/test_token_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
218218
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
219219
scopes = ["s2", "s1", "s3"] # Not in particular order
220220
now = 1000
221+
self.cache._set(data_to_at={"key_id": "key_id"})
221222
self.cache.add({
222223
"data": data,
223224
"client_id": "my_client_id",

0 commit comments

Comments
 (0)