Skip to content

Commit 0e28482

Browse files
committed
Customizable token cache
Customizable data and response to be saved into token cache
1 parent 7db6c2c commit 0e28482

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

msal/application.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import functools
23
import json
34
import time
@@ -238,6 +239,10 @@ class ClientApplication(object):
238239
"You can enable broker by following these instructions. "
239240
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")
240241

242+
_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache
243+
"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
244+
}
245+
241246
def __init__(
242247
self, client_id,
243248
client_credential=None, authority=None, validate_authority=True,
@@ -651,6 +656,7 @@ def __init__(
651656

652657
self._decide_broker(allow_broker, enable_pii_log)
653658
self.token_cache = token_cache or TokenCache()
659+
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
654660
self._region_configured = azure_region
655661
self._region_detected = None
656662
self.client, self._regional_client = self._build_client(
@@ -1528,9 +1534,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15281534
"realm": authority.tenant,
15291535
"home_account_id": (account or {}).get("home_account_id"),
15301536
}
1531-
key_id = kwargs.get("data", {}).get("key_id")
1532-
if key_id: # Some token types (SSH-certs, POP) are bound to a key
1533-
query["key_id"] = key_id
1537+
for field_in_data, field_in_cache in self._TOKEN_CACHE_DATA.items():
1538+
value = kwargs.get("data", {}).get(field_in_data)
1539+
if value:
1540+
query[field_in_cache] = value
15341541
now = time.time()
15351542
refresh_reason = msal.telemetry.AT_ABSENT
15361543
for entry in self.token_cache.search( # A generator allows us to

msal/token_cache.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import json
1+
from __future__ import annotations
2+
import json
23
import threading
34
import time
5+
from typing import Optional # Needed in Python 3.7 & 3.8
46
import logging
57
import warnings
68

@@ -39,6 +41,25 @@ class AuthorityType:
3941
ADFS = "ADFS"
4042
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA
4143

44+
_data_to_at: dict[str, str] = { # field_in_data: field_in_cache
45+
# Store extra data which we explicitly allow,
46+
# so that we won't accidentally store a user's password etc.
47+
# It can be used to store for example key_id used in SSH-cert or POP
48+
}
49+
_response_to_at: dict[str, str] = { # field_in_response: field_in_cache
50+
}
51+
52+
def _set(
53+
self,
54+
*,
55+
data_to_at: Optional[dict[str, str]] = None,
56+
response_to_at: Optional[dict[str, str]] = None,
57+
) -> None:
58+
# This helper should probably be better in __init__(),
59+
# but there is no easy way for MSAL EX to pick up a kwargs
60+
self._data_to_at = data_to_at or {}
61+
self._response_to_at = response_to_at or {}
62+
4263
def __init__(self):
4364
self._lock = threading.RLock()
4465
self._cache = {}
@@ -267,11 +288,14 @@ def __add(self, event, now=None):
267288
"expires_on": str(now + expires_in), # Same here
268289
"extended_expires_on": str(now + ext_expires_in) # Same here
269290
}
270-
at.update({k: data[k] for k in data if k in {
271-
# Also store extra data which we explicitly allow
272-
# So that we won't accidentally store a user's password etc.
273-
"key_id", # It happens in SSH-cert or POP scenario
274-
}})
291+
for field_in_resp, field_in_cache in self._response_to_at.items():
292+
value = response.get(field_in_resp)
293+
if value:
294+
at[field_in_cache] = value
295+
for field_in_data, field_in_cache in self._data_to_at.items():
296+
value = data.get(field_in_data)
297+
if value:
298+
at[field_in_cache] = value
275299
if "refresh_in" in response:
276300
refresh_in = response["refresh_in"] # It is an integer
277301
at["refresh_on"] = str(now + refresh_in) # Schema wants a string

tests/test_token_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
218218
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
219219
scopes = ["s2", "s1", "s3"] # Not in particular order
220220
now = 1000
221+
self.cache._set(data_to_at={"key_id": "key_id"})
221222
self.cache.add({
222223
"data": data,
223224
"client_id": "my_client_id",

0 commit comments

Comments
 (0)