@@ -88,20 +88,69 @@ def __init__(self):
8888 "appmetadata-{}-{}" .format (environment or "" , client_id or "" ),
8989 }
9090
91- def find (self , credential_type , target = None , query = None ):
92- target = target or []
91+ def _get_access_token (
92+ self ,
93+ home_account_id , environment , client_id , realm , target , # Together they form a compound key
94+ default = None ,
95+ ): # O(1)
96+ return self ._get (
97+ self .CredentialType .ACCESS_TOKEN ,
98+ self .key_makers [TokenCache .CredentialType .ACCESS_TOKEN ](
99+ home_account_id = home_account_id ,
100+ environment = environment ,
101+ client_id = client_id ,
102+ realm = realm ,
103+ target = " " .join (target ),
104+ ),
105+ default = default )
106+
107+ def _get_app_metadata (self , environment , client_id , default = None ): # O(1)
108+ return self ._get (
109+ self .CredentialType .APP_METADATA ,
110+ self .key_makers [TokenCache .CredentialType .APP_METADATA ](
111+ environment = environment ,
112+ client_id = client_id ,
113+ ),
114+ default = default )
115+
116+ def _get (self , credential_type , key , default = None ): # O(1)
117+ with self ._lock :
118+ return self ._cache .get (credential_type , {}).get (key , default )
119+
120+ def _find (self , credential_type , target = None , query = None ): # O(n) generator
121+ """Returns a generator of matching entries.
122+
123+ It is O(1) for AT hits, and O(n) for other types.
124+ Note that it holds a lock during the entire search.
125+ """
126+ target = sorted (target or []) # Match the order sorted by add()
93127 assert isinstance (target , list ), "Invalid parameter type"
128+
129+ preferred_result = None
130+ if (credential_type == self .CredentialType .ACCESS_TOKEN
131+ and "home_account_id" in query and "environment" in query
132+ and "client_id" in query and "realm" in query and target
133+ ): # Special case for O(1) AT lookup
134+ preferred_result = self ._get_access_token (
135+ query ["home_account_id" ], query ["environment" ],
136+ query ["client_id" ], query ["realm" ], target )
137+ if preferred_result :
138+ yield preferred_result
139+
94140 target_set = set (target )
95141 with self ._lock :
96142 # Since the target inside token cache key is (per schema) unsorted,
97143 # there is no point to attempt an O(1) key-value search here.
98144 # So we always do an O(n) in-memory search.
99- return [entry
100- for entry in self ._cache .get (credential_type , {}).values ()
101- if is_subdict_of (query or {}, entry )
102- and (target_set <= set (entry .get ("target" , "" ).split ())
103- if target else True )
104- ]
145+ for entry in self ._cache .get (credential_type , {}).values ():
146+ if is_subdict_of (query or {}, entry ) and (
147+ target_set <= set (entry .get ("target" , "" ).split ())
148+ if target else True ):
149+ if entry != preferred_result : # Avoid yielding the same entry twice
150+ yield entry
151+
152+ def find (self , credential_type , target = None , query = None ): # Obsolete. Use _find() instead.
153+ return list (self ._find (credential_type , target = target , query = query ))
105154
106155 def add (self , event , now = None ):
107156 """Handle a token obtaining event, and add tokens into cache."""
@@ -160,7 +209,7 @@ def __add(self, event, now=None):
160209 decode_id_token (id_token , client_id = event ["client_id" ]) if id_token else {})
161210 client_info , home_account_id = self .__parse_account (response , id_token_claims )
162211
163- target = ' ' .join (event .get ("scope" ) or []) # Per schema, we don't sort it
212+ target = ' ' .join (sorted ( event .get ("scope" ) or [])) # Schema should have required sorting
164213
165214 with self ._lock :
166215 now = int (time .time () if now is None else now )
0 commit comments