Skip to content

Commit 5f0befb

Browse files
authored
Merge pull request #41 from AzureAD/signout-family
Add remove_account() API
2 parents f76f3c3 + 57b1195 commit 5f0befb

File tree

3 files changed

+171
-42
lines changed

3 files changed

+171
-42
lines changed

msal/application.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,49 @@ def _get_authority_aliases(self, instance):
280280
return [alias for alias in group if alias != instance]
281281
return []
282282

283+
def remove_account(self, account):
284+
"""Sign me out and forget me from token cache"""
285+
self._forget_me(account)
286+
287+
def _sign_out(self, home_account):
288+
# Remove all relevant RTs and ATs from token cache
289+
owned_by_home_account = {
290+
"environment": home_account["environment"],
291+
"home_account_id": home_account["home_account_id"],} # realm-independent
292+
app_metadata = self._get_app_metadata(home_account["environment"])
293+
# Remove RTs/FRTs, and they are realm-independent
294+
for rt in [rt for rt in self.token_cache.find(
295+
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account)
296+
# Do RT's app ownership check as a precaution, in case family apps
297+
# and 3rd-party apps share same token cache, although they should not.
298+
if rt["client_id"] == self.client_id or (
299+
app_metadata.get("family_id") # Now let's settle family business
300+
and rt.get("family_id") == app_metadata["family_id"])
301+
]:
302+
self.token_cache.remove_rt(rt)
303+
for at in self.token_cache.find( # Remove ATs
304+
# Regardless of realm, b/c we've removed realm-independent RTs anyway
305+
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account):
306+
# To avoid the complexity of locating sibling family app's AT,
307+
# we skip AT's app ownership check.
308+
# It means ATs for other apps will also be removed, it is OK because:
309+
# * non-family apps are not supposed to share token cache to begin with;
310+
# * Even if it happens, we keep other app's RT already, so SSO still works
311+
self.token_cache.remove_at(at)
312+
313+
def _forget_me(self, home_account):
314+
# It implies signout, and then also remove all relevant accounts and IDTs
315+
self._sign_out(home_account)
316+
owned_by_home_account = {
317+
"environment": home_account["environment"],
318+
"home_account_id": home_account["home_account_id"],} # realm-independent
319+
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
320+
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account):
321+
self.token_cache.remove_idt(idt)
322+
for a in self.token_cache.find( # Remove Accounts, regardless of realm
323+
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
324+
self.token_cache.remove_account(a)
325+
283326
def acquire_token_silent(
284327
self,
285328
scopes, # type: List[str]
@@ -364,10 +407,7 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
364407
"home_account_id": (account or {}).get("home_account_id"),
365408
# "realm": authority.tenant, # AAD RTs are tenant-independent
366409
}
367-
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
368-
TokenCache.CredentialType.APP_METADATA, query={
369-
"environment": authority.instance, "client_id": self.client_id})
370-
app_metadata = apps[0] if apps else {}
410+
app_metadata = self._get_app_metadata(authority.instance)
371411
if not app_metadata: # Meaning this app is now used for the first time.
372412
# When/if we have a way to directly detect current app's family,
373413
# we'll rewrite this block, to support multiple families.
@@ -396,6 +436,12 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
396436
return self._acquire_token_silent_by_finding_specific_refresh_token(
397437
authority, scopes, dict(query, client_id=self.client_id), **kwargs)
398438

439+
def _get_app_metadata(self, environment):
440+
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
441+
TokenCache.CredentialType.APP_METADATA, query={
442+
"environment": environment, "client_id": self.client_id})
443+
return apps[0] if apps else {}
444+
399445
def _acquire_token_silent_by_finding_specific_refresh_token(
400446
self, authority, scopes, query,
401447
rt_remover=None, break_condition=lambda response: False, **kwargs):

msal/token_cache.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class AuthorityType:
3939
def __init__(self):
4040
self._lock = threading.RLock()
4141
self._cache = {}
42+
self.key_makers = {
43+
self.CredentialType.REFRESH_TOKEN: self._build_rt_key,
44+
self.CredentialType.ACCESS_TOKEN: self._build_at_key,
45+
self.CredentialType.ID_TOKEN: self._build_idt_key,
46+
self.CredentialType.ACCOUNT: self._build_account_key,
47+
}
4248

4349
def find(self, credential_type, target=None, query=None):
4450
target = target or []
@@ -83,14 +89,9 @@ def add(self, event, now=None):
8389
with self._lock:
8490

8591
if access_token:
86-
key = "-".join([
87-
home_account_id or "",
88-
environment or "",
89-
self.CredentialType.ACCESS_TOKEN,
90-
event.get("client_id", ""),
91-
realm or "",
92-
target,
93-
]).lower()
92+
key = self._build_at_key(
93+
home_account_id, environment, event.get("client_id", ""),
94+
realm, target)
9495
now = time.time() if now is None else now
9596
expires_in = response.get("expires_in", 3599)
9697
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
@@ -110,11 +111,7 @@ def add(self, event, now=None):
110111
if client_info:
111112
decoded_id_token = json.loads(
112113
base64decode(id_token.split('.')[1])) if id_token else {}
113-
key = "-".join([
114-
home_account_id or "",
115-
environment or "",
116-
realm or "",
117-
]).lower()
114+
key = self._build_account_key(home_account_id, environment, realm)
118115
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
119116
"home_account_id": home_account_id,
120117
"environment": environment,
@@ -129,14 +126,8 @@ def add(self, event, now=None):
129126
}
130127

131128
if id_token:
132-
key = "-".join([
133-
home_account_id or "",
134-
environment or "",
135-
self.CredentialType.ID_TOKEN,
136-
event.get("client_id", ""),
137-
realm or "",
138-
"" # Albeit irrelevant, schema requires an empty scope here
139-
]).lower()
129+
key = self._build_idt_key(
130+
home_account_id, environment, event.get("client_id", ""), realm)
140131
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
141132
"credential_type": self.CredentialType.ID_TOKEN,
142133
"secret": id_token,
@@ -170,6 +161,24 @@ def add(self, event, now=None):
170161
"family_id": response.get("foci"), # None is also valid
171162
}
172163

164+
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
165+
# Modify the specified old_entry with new_key_value_pairs,
166+
# or remove the old_entry if the new_key_value_pairs is None.
167+
168+
# This helper exists to consolidate all token modify/remove behaviors,
169+
# so that the sub-classes will have only one method to work on,
170+
# instead of patching a pair of update_xx() and remove_xx() per type.
171+
# You can monkeypatch self.key_makers to support more types on-the-fly.
172+
key = self.key_makers[credential_type](**old_entry)
173+
with self._lock:
174+
if new_key_value_pairs: # Update with them
175+
entries = self._cache.setdefault(credential_type, {})
176+
entry = entries.get(key, {}) # key usually exists, but we'll survive its absence
177+
entry.update(new_key_value_pairs)
178+
else: # Remove old_entry
179+
self._cache.setdefault(credential_type, {}).pop(key, None)
180+
181+
173182
@staticmethod
174183
def _build_appmetadata_key(environment, client_id):
175184
return "appmetadata-{}-{}".format(environment or "", client_id or "")
@@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id):
178187
def _build_rt_key(
179188
cls,
180189
home_account_id=None, environment=None, client_id=None, target=None,
181-
**ignored):
190+
**ignored_payload_from_a_real_token):
182191
return "-".join([
183192
home_account_id or "",
184193
environment or "",
@@ -189,16 +198,61 @@ def _build_rt_key(
189198
]).lower()
190199

191200
def remove_rt(self, rt_item):
192-
key = self._build_rt_key(**rt_item)
193-
with self._lock:
194-
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)
201+
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
202+
return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)
195203

196204
def update_rt(self, rt_item, new_rt):
197-
key = self._build_rt_key(**rt_item)
198-
with self._lock:
199-
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
200-
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
201-
rt["secret"] = new_rt
205+
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
206+
return self.modify(
207+
self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt})
208+
209+
@classmethod
210+
def _build_at_key(cls,
211+
home_account_id=None, environment=None, client_id=None,
212+
realm=None, target=None, **ignored_payload_from_a_real_token):
213+
return "-".join([
214+
home_account_id or "",
215+
environment or "",
216+
cls.CredentialType.ACCESS_TOKEN,
217+
client_id,
218+
realm or "",
219+
target or "",
220+
]).lower()
221+
222+
def remove_at(self, at_item):
223+
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
224+
return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)
225+
226+
@classmethod
227+
def _build_idt_key(cls,
228+
home_account_id=None, environment=None, client_id=None, realm=None,
229+
**ignored_payload_from_a_real_token):
230+
return "-".join([
231+
home_account_id or "",
232+
environment or "",
233+
cls.CredentialType.ID_TOKEN,
234+
client_id or "",
235+
realm or "",
236+
"" # Albeit irrelevant, schema requires an empty scope here
237+
]).lower()
238+
239+
def remove_idt(self, idt_item):
240+
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
241+
return self.modify(self.CredentialType.ID_TOKEN, idt_item)
242+
243+
@classmethod
244+
def _build_account_key(cls,
245+
home_account_id=None, environment=None, realm=None,
246+
**ignored_payload_from_a_real_entry):
247+
return "-".join([
248+
home_account_id or "",
249+
environment or "",
250+
realm or "",
251+
]).lower()
252+
253+
def remove_account(self, account_item):
254+
assert "authority_type" in account_item
255+
return self.modify(self.CredentialType.ACCOUNT, account_item)
202256

203257

204258
class SerializableTokenCache(TokenCache):
@@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache):
221275
...
222276
223277
:var bool has_state_changed:
224-
Indicates whether the cache state has changed since last
278+
Indicates whether the cache state in the memory has changed since last
225279
:func:`~serialize` or :func:`~deserialize` call.
226280
"""
227281
has_state_changed = False
@@ -230,12 +284,9 @@ def add(self, event, **kwargs):
230284
super(SerializableTokenCache, self).add(event, **kwargs)
231285
self.has_state_changed = True
232286

233-
def remove_rt(self, rt_item):
234-
super(SerializableTokenCache, self).remove_rt(rt_item)
235-
self.has_state_changed = True
236-
237-
def update_rt(self, rt_item, new_rt):
238-
super(SerializableTokenCache, self).update_rt(rt_item, new_rt)
287+
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
288+
super(SerializableTokenCache, self).modify(
289+
credential_type, old_entry, new_key_value_pairs)
239290
self.has_state_changed = True
240291

241292
def deserialize(self, state):

tests/test_application.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,14 @@ def setUp(self):
174174
self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)}
175175
self.frt = "what the frt"
176176
self.cache = msal.SerializableTokenCache()
177+
self.preexisting_family_app_id = "preexisting_family_app"
177178
self.cache.add({ # Pre-populate a FRT
178-
"client_id": "preexisting_family_app",
179+
"client_id": self.preexisting_family_app_id,
179180
"scope": self.scopes,
180181
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
181182
"response": TokenCacheTestCase.build_response(
183+
access_token="Siblings won't share AT. test_remove_account() will.",
184+
id_token=TokenCacheTestCase.build_id_token(),
182185
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
183186
}) # The add(...) helper populates correct home_account_id for future searching
184187

@@ -239,6 +242,35 @@ def tester(url, data=None, **kwargs):
239242

240243
# Will not test scenario of app leaving family. Per specs, it won't happen.
241244

245+
def test_family_app_remove_account(self):
246+
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
247+
app = ClientApplication(
248+
self.preexisting_family_app_id,
249+
authority=self.authority_url, token_cache=self.cache)
250+
account = app.get_accounts()[0]
251+
mine = {"home_account_id": account["home_account_id"]}
252+
253+
self.assertNotEqual([], self.cache.find(
254+
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
255+
self.assertNotEqual([], self.cache.find(
256+
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
257+
self.assertNotEqual([], self.cache.find(
258+
self.cache.CredentialType.ID_TOKEN, query=mine))
259+
self.assertNotEqual([], self.cache.find(
260+
self.cache.CredentialType.ACCOUNT, query=mine))
261+
262+
app.remove_account(account)
263+
264+
self.assertEqual([], self.cache.find(
265+
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
266+
self.assertEqual([], self.cache.find(
267+
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
268+
self.assertEqual([], self.cache.find(
269+
self.cache.CredentialType.ID_TOKEN, query=mine))
270+
self.assertEqual([], self.cache.find(
271+
self.cache.CredentialType.ACCOUNT, query=mine))
272+
273+
242274
class TestClientApplicationForAuthorityMigration(unittest.TestCase):
243275

244276
@classmethod

0 commit comments

Comments
 (0)