Skip to content

Commit 57b1195

Browse files
committed
TokenCache now have one modify() to rule them all.
1 parent 6362813 commit 57b1195

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

msal/token_cache.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class AuthorityType:
3939
def __init__(self):
4040
self._lock = threading.RLock()
4141
self._cache = {}
42+
self.key_makers = {
43+
self.CredentialType.REFRESH_TOKEN: self._build_rt_key,
44+
self.CredentialType.ACCESS_TOKEN: self._build_at_key,
45+
self.CredentialType.ID_TOKEN: self._build_idt_key,
46+
self.CredentialType.ACCOUNT: self._build_account_key,
47+
}
4248

4349
def find(self, credential_type, target=None, query=None):
4450
target = target or []
@@ -155,6 +161,24 @@ def add(self, event, now=None):
155161
"family_id": response.get("foci"), # None is also valid
156162
}
157163

164+
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
165+
# Modify the specified old_entry with new_key_value_pairs,
166+
# or remove the old_entry if the new_key_value_pairs is None.
167+
168+
# This helper exists to consolidate all token modify/remove behaviors,
169+
# so that the sub-classes will have only one method to work on,
170+
# instead of patching a pair of update_xx() and remove_xx() per type.
171+
# You can monkeypatch self.key_makers to support more types on-the-fly.
172+
key = self.key_makers[credential_type](**old_entry)
173+
with self._lock:
174+
if new_key_value_pairs: # Update with them
175+
entries = self._cache.setdefault(credential_type, {})
176+
entry = entries.get(key, {}) # key usually exists, but we'll survive its absence
177+
entry.update(new_key_value_pairs)
178+
else: # Remove old_entry
179+
self._cache.setdefault(credential_type, {}).pop(key, None)
180+
181+
158182
@staticmethod
159183
def _build_appmetadata_key(environment, client_id):
160184
return "appmetadata-{}-{}".format(environment or "", client_id or "")
@@ -175,17 +199,12 @@ def _build_rt_key(
175199

176200
def remove_rt(self, rt_item):
177201
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
178-
key = self._build_rt_key(**rt_item)
179-
with self._lock:
180-
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)
202+
return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)
181203

182204
def update_rt(self, rt_item, new_rt):
183205
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
184-
key = self._build_rt_key(**rt_item)
185-
with self._lock:
186-
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
187-
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
188-
rt["secret"] = new_rt
206+
return self.modify(
207+
self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt})
189208

190209
@classmethod
191210
def _build_at_key(cls,
@@ -202,9 +221,7 @@ def _build_at_key(cls,
202221

203222
def remove_at(self, at_item):
204223
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
205-
key = self._build_at_key(**at_item)
206-
with self._lock:
207-
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {}).pop(key, None)
224+
return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)
208225

209226
@classmethod
210227
def _build_idt_key(cls,
@@ -221,9 +238,7 @@ def _build_idt_key(cls,
221238

222239
def remove_idt(self, idt_item):
223240
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
224-
key = self._build_idt_key(**idt_item)
225-
with self._lock:
226-
self._cache.setdefault(self.CredentialType.ID_TOKEN, {}).pop(key, None)
241+
return self.modify(self.CredentialType.ID_TOKEN, idt_item)
227242

228243
@classmethod
229244
def _build_account_key(cls,
@@ -237,9 +252,7 @@ def _build_account_key(cls,
237252

238253
def remove_account(self, account_item):
239254
assert "authority_type" in account_item
240-
key = self._build_account_key(**account_item)
241-
with self._lock:
242-
self._cache.setdefault(self.CredentialType.ACCOUNT, {}).pop(key, None)
255+
return self.modify(self.CredentialType.ACCOUNT, account_item)
243256

244257

245258
class SerializableTokenCache(TokenCache):
@@ -262,7 +275,7 @@ class SerializableTokenCache(TokenCache):
262275
...
263276
264277
:var bool has_state_changed:
265-
Indicates whether the cache state has changed since last
278+
Indicates whether the cache state in the memory has changed since last
266279
:func:`~serialize` or :func:`~deserialize` call.
267280
"""
268281
has_state_changed = False
@@ -271,12 +284,9 @@ def add(self, event, **kwargs):
271284
super(SerializableTokenCache, self).add(event, **kwargs)
272285
self.has_state_changed = True
273286

274-
def remove_rt(self, rt_item):
275-
super(SerializableTokenCache, self).remove_rt(rt_item)
276-
self.has_state_changed = True
277-
278-
def update_rt(self, rt_item, new_rt):
279-
super(SerializableTokenCache, self).update_rt(rt_item, new_rt)
287+
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
288+
super(SerializableTokenCache, self).modify(
289+
credential_type, old_entry, new_key_value_pairs)
280290
self.has_state_changed = True
281291

282292
def deserialize(self, state):

0 commit comments

Comments
 (0)