Skip to content

Commit 167e954

Browse files
committed
Remove all AT, RT, FRT belongs to current account
1 parent a6dbbff commit 167e954

File tree

3 files changed

+107
-49
lines changed

3 files changed

+107
-49
lines changed

msal/application.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -280,36 +280,23 @@ def _get_authority_aliases(self, instance):
280280
return [alias for alias in group if alias != instance]
281281
return []
282282

283-
def sign_out(self, account):
283+
def remove_account(self, home_account):
284284
"""Remove all relevant RTs and ATs from token cache"""
285285
owned_by_account = {
286-
"environment": account["environment"],
287-
"home_account_id": (account or {}).get("home_account_id"),}
288-
289-
owned_by_account_and_app = dict(owned_by_account, client=self.client_id)
290-
for rt in self.token_cache.find( # Remove RTs
291-
TokenCache.CredentialType.REFRESH_TOKEN,
292-
query=owned_by_account_and_app):
286+
"environment": home_account["environment"],
287+
"home_account_id": home_account["home_account_id"],} # realm-independent
288+
for rt in self.token_cache.find( # Remove RTs, and RTs are realm-independent
289+
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_account):
293290
self.token_cache.remove_rt(rt)
294-
for at in self.token_cache.find( # Remove ATs
295-
TokenCache.CredentialType.ACCESS_TOKEN,
296-
query=owned_by_account_and_app): # regardless of realm
297-
self.token_cache.remove_at(at) # TODO
298-
299-
app_metadata = self._get_app_metadata(account["environment"])
300-
if app_metadata.get("family_id"): # Now let's settle family business
301-
for rt in self.token_cache.find( # Remove FRTs
302-
TokenCache.CredentialType.REFRESH_TOKEN, query=dict(
303-
owned_by_account,
304-
family_id=app_metadata["family_id"])):
305-
self.token_cache.remove_rt(rt)
306-
for sibling_app in self.token_cache.find( # Remove siblings' ATs
307-
TokenCache.CredentialType.APP_METADATA,
308-
query={"family_id": app_metadata.get["family_id"]}):
309-
for at in self.token_cache.find( # Remove ATs, regardless of realm
310-
TokenCache.CredentialType.ACCESS_TOKEN, query=dict(
311-
owned_by_account, client_id=sibling_app["client_id"])):
312-
self.token_cache.remove_at(at) # TODO
291+
for at in self.token_cache.find( # Remove ATs, regardless of realm
292+
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_account):
293+
self.token_cache.remove_at(at)
294+
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
295+
TokenCache.CredentialType.ID_TOKEN, query=owned_by_account):
296+
self.token_cache.remove_idt(idt)
297+
for a in self.token_cache.find( # Remove Accounts, regardless of realm
298+
TokenCache.CredentialType.ACCOUNT, query=owned_by_account):
299+
self.token_cache.remove_account(a)
313300

314301
def acquire_token_silent(
315302
self,

msal/token_cache.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,9 @@ def add(self, event, now=None):
8383
with self._lock:
8484

8585
if access_token:
86-
key = "-".join([
87-
home_account_id or "",
88-
environment or "",
89-
self.CredentialType.ACCESS_TOKEN,
90-
event.get("client_id", ""),
91-
realm or "",
92-
target,
93-
]).lower()
86+
key = self._build_at_key(
87+
home_account_id, environment, event.get("client_id", ""),
88+
realm, target)
9489
now = time.time() if now is None else now
9590
expires_in = response.get("expires_in", 3599)
9691
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
@@ -110,11 +105,7 @@ def add(self, event, now=None):
110105
if client_info:
111106
decoded_id_token = json.loads(
112107
base64decode(id_token.split('.')[1])) if id_token else {}
113-
key = "-".join([
114-
home_account_id or "",
115-
environment or "",
116-
realm or "",
117-
]).lower()
108+
key = self._build_account_key(home_account_id, environment, realm)
118109
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
119110
"home_account_id": home_account_id,
120111
"environment": environment,
@@ -129,14 +120,8 @@ def add(self, event, now=None):
129120
}
130121

131122
if id_token:
132-
key = "-".join([
133-
home_account_id or "",
134-
environment or "",
135-
self.CredentialType.ID_TOKEN,
136-
event.get("client_id", ""),
137-
realm or "",
138-
"" # Albeit irrelevant, schema requires an empty scope here
139-
]).lower()
123+
key = self._build_idt_key(
124+
home_account_id, environment, event.get("client_id", ""), realm)
140125
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
141126
"credential_type": self.CredentialType.ID_TOKEN,
142127
"secret": id_token,
@@ -178,7 +163,7 @@ def _build_appmetadata_key(environment, client_id):
178163
def _build_rt_key(
179164
cls,
180165
home_account_id=None, environment=None, client_id=None, target=None,
181-
**ignored):
166+
**ignored_payload_from_a_real_token):
182167
return "-".join([
183168
home_account_id or "",
184169
environment or "",
@@ -189,17 +174,73 @@ def _build_rt_key(
189174
]).lower()
190175

191176
def remove_rt(self, rt_item):
177+
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
192178
key = self._build_rt_key(**rt_item)
193179
with self._lock:
194180
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)
195181

196182
def update_rt(self, rt_item, new_rt):
183+
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
197184
key = self._build_rt_key(**rt_item)
198185
with self._lock:
199186
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
200187
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
201188
rt["secret"] = new_rt
202189

190+
@classmethod
191+
def _build_at_key(cls,
192+
home_account_id=None, environment=None, client_id=None,
193+
realm=None, target=None, **ignored_payload_from_a_real_token):
194+
return "-".join([
195+
home_account_id or "",
196+
environment or "",
197+
cls.CredentialType.ACCESS_TOKEN,
198+
client_id,
199+
realm or "",
200+
target or "",
201+
]).lower()
202+
203+
def remove_at(self, at_item):
204+
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
205+
key = self._build_at_key(**at_item)
206+
with self._lock:
207+
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {}).pop(key, None)
208+
209+
@classmethod
210+
def _build_idt_key(cls,
211+
home_account_id=None, environment=None, client_id=None, realm=None,
212+
**ignored_payload_from_a_real_token):
213+
return "-".join([
214+
home_account_id or "",
215+
environment or "",
216+
cls.CredentialType.ID_TOKEN,
217+
client_id or "",
218+
realm or "",
219+
"" # Albeit irrelevant, schema requires an empty scope here
220+
]).lower()
221+
222+
def remove_idt(self, idt_item):
223+
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
224+
key = self._build_idt_key(**idt_item)
225+
with self._lock:
226+
self._cache.setdefault(self.CredentialType.ID_TOKEN, {}).pop(key, None)
227+
228+
@classmethod
229+
def _build_account_key(cls,
230+
home_account_id=None, environment=None, realm=None,
231+
**ignored_payload_from_a_real_entry):
232+
return "-".join([
233+
home_account_id or "",
234+
environment or "",
235+
realm or "",
236+
]).lower()
237+
238+
def remove_account(self, account_item):
239+
assert "authority_type" in account_item
240+
key = self._build_account_key(**account_item)
241+
with self._lock:
242+
self._cache.setdefault(self.CredentialType.ACCOUNT, {}).pop(key, None)
243+
203244

204245
class SerializableTokenCache(TokenCache):
205246
"""This serialization can be a starting point to implement your own persistence.

tests/test_application.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def setUp(self):
179179
"scope": self.scopes,
180180
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
181181
"response": TokenCacheTestCase.build_response(
182+
access_token="Siblings won't share AT. test_remove_account() will.",
183+
id_token=TokenCacheTestCase.build_id_token(),
182184
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
183185
}) # The add(...) helper populates correct home_account_id for future searching
184186

@@ -239,6 +241,34 @@ def tester(url, data=None, **kwargs):
239241

240242
# Will not test scenario of app leaving family. Per specs, it won't happen.
241243

244+
def test_get_remove_account(self):
245+
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
246+
app = ClientApplication(
247+
"family_app_2", authority=self.authority_url, token_cache=self.cache)
248+
account = app.get_accounts()[0]
249+
mine = {"home_account_id": account["home_account_id"]}
250+
251+
self.assertNotEqual([], self.cache.find(
252+
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
253+
self.assertNotEqual([], self.cache.find(
254+
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
255+
self.assertNotEqual([], self.cache.find(
256+
self.cache.CredentialType.ID_TOKEN, query=mine))
257+
self.assertNotEqual([], self.cache.find(
258+
self.cache.CredentialType.ACCOUNT, query=mine))
259+
260+
app.remove_account(account)
261+
262+
self.assertEqual([], self.cache.find(
263+
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
264+
self.assertEqual([], self.cache.find(
265+
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
266+
self.assertEqual([], self.cache.find(
267+
self.cache.CredentialType.ID_TOKEN, query=mine))
268+
self.assertEqual([], self.cache.find(
269+
self.cache.CredentialType.ACCOUNT, query=mine))
270+
271+
242272
class TestClientApplicationForAuthorityMigration(unittest.TestCase):
243273

244274
@classmethod

0 commit comments

Comments
 (0)