Skip to content

Commit 30affeb

Browse files
committed
Now get_accounts() ensures proper authority type
1 parent cf9e30f commit 30affeb

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

msal/application.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,17 +236,21 @@ def get_accounts(self, username=None):
236236
Your app can choose to display those information to end user,
237237
and allow user to choose one of his/her accounts to proceed.
238238
"""
239-
# The following implementation finds accounts only from saved accounts,
240-
# but does NOT correlate them with saved RTs. It probably won't matter,
241-
# because in MSAL universe, there are always Accounts and RTs together.
242-
accounts = self.token_cache.find(
239+
accounts = [a for a in self.token_cache.find( # Find all useful accounts
243240
self.token_cache.CredentialType.ACCOUNT,
244241
query={"environment": self.authority.instance})
242+
if a["authority_type"] in (
243+
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)]
245244
if username:
246245
# Federated account["username"] from AAD could contain mixed case
247246
lowercase_username = username.lower()
248247
accounts = [a for a in accounts
249248
if a["username"].lower() == lowercase_username]
249+
# Does not further filter by existing RTs here. It probably won't matter.
250+
# Because in most cases Accounts and RTs co-exist.
251+
# Even in the rare case when an RT is revoked and then removed,
252+
# acquire_token_silent() would then yield no result,
253+
# apps would fall back to other acquire methods. This is the standard pattern.
250254
return accounts
251255

252256
def acquire_token_silent(

msal/token_cache.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class CredentialType:
3131
ACCOUNT = "Account" # Not exactly a credential type, but we put it here
3232
ID_TOKEN = "IdToken"
3333

34+
class AuthorityType:
35+
ADFS = "ADFS"
36+
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA
37+
3438
def __init__(self):
3539
self._lock = threading.RLock()
3640
self._cache = {}
@@ -118,8 +122,8 @@ def add(self, event, now=None):
118122
"oid", decoded_id_token.get("sub")),
119123
"username": decoded_id_token.get("preferred_username"),
120124
"authority_type":
121-
"ADFS" if realm == "adfs"
122-
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA
125+
self.AuthorityType.ADFS if realm == "adfs"
126+
else self.AuthorityType.MSSTS,
123127
# "client_info": response.get("client_info"), # Optional
124128
}
125129

0 commit comments

Comments
 (0)