@@ -39,6 +39,12 @@ class AuthorityType:
3939 def __init__ (self ):
4040 self ._lock = threading .RLock ()
4141 self ._cache = {}
42+ self .key_makers = {
43+ self .CredentialType .REFRESH_TOKEN : self ._build_rt_key ,
44+ self .CredentialType .ACCESS_TOKEN : self ._build_at_key ,
45+ self .CredentialType .ID_TOKEN : self ._build_idt_key ,
46+ self .CredentialType .ACCOUNT : self ._build_account_key ,
47+ }
4248
4349 def find (self , credential_type , target = None , query = None ):
4450 target = target or []
@@ -83,14 +89,9 @@ def add(self, event, now=None):
8389 with self ._lock :
8490
8591 if access_token :
86- key = "-" .join ([
87- home_account_id or "" ,
88- environment or "" ,
89- self .CredentialType .ACCESS_TOKEN ,
90- event .get ("client_id" , "" ),
91- realm or "" ,
92- target ,
93- ]).lower ()
92+ key = self ._build_at_key (
93+ home_account_id , environment , event .get ("client_id" , "" ),
94+ realm , target )
9495 now = time .time () if now is None else now
9596 expires_in = response .get ("expires_in" , 3599 )
9697 self ._cache .setdefault (self .CredentialType .ACCESS_TOKEN , {})[key ] = {
@@ -110,11 +111,7 @@ def add(self, event, now=None):
110111 if client_info :
111112 decoded_id_token = json .loads (
112113 base64decode (id_token .split ('.' )[1 ])) if id_token else {}
113- key = "-" .join ([
114- home_account_id or "" ,
115- environment or "" ,
116- realm or "" ,
117- ]).lower ()
114+ key = self ._build_account_key (home_account_id , environment , realm )
118115 self ._cache .setdefault (self .CredentialType .ACCOUNT , {})[key ] = {
119116 "home_account_id" : home_account_id ,
120117 "environment" : environment ,
@@ -129,14 +126,8 @@ def add(self, event, now=None):
129126 }
130127
131128 if id_token :
132- key = "-" .join ([
133- home_account_id or "" ,
134- environment or "" ,
135- self .CredentialType .ID_TOKEN ,
136- event .get ("client_id" , "" ),
137- realm or "" ,
138- "" # Albeit irrelevant, schema requires an empty scope here
139- ]).lower ()
129+ key = self ._build_idt_key (
130+ home_account_id , environment , event .get ("client_id" , "" ), realm )
140131 self ._cache .setdefault (self .CredentialType .ID_TOKEN , {})[key ] = {
141132 "credential_type" : self .CredentialType .ID_TOKEN ,
142133 "secret" : id_token ,
@@ -170,6 +161,24 @@ def add(self, event, now=None):
170161 "family_id" : response .get ("foci" ), # None is also valid
171162 }
172163
164+ def modify (self , credential_type , old_entry , new_key_value_pairs = None ):
165+ # Modify the specified old_entry with new_key_value_pairs,
166+ # or remove the old_entry if the new_key_value_pairs is None.
167+
168+ # This helper exists to consolidate all token modify/remove behaviors,
169+ # so that the sub-classes will have only one method to work on,
170+ # instead of patching a pair of update_xx() and remove_xx() per type.
171+ # You can monkeypatch self.key_makers to support more types on-the-fly.
172+ key = self .key_makers [credential_type ](** old_entry )
173+ with self ._lock :
174+ if new_key_value_pairs : # Update with them
175+ entries = self ._cache .setdefault (credential_type , {})
176+ entry = entries .get (key , {}) # key usually exists, but we'll survive its absence
177+ entry .update (new_key_value_pairs )
178+ else : # Remove old_entry
179+ self ._cache .setdefault (credential_type , {}).pop (key , None )
180+
181+
173182 @staticmethod
174183 def _build_appmetadata_key (environment , client_id ):
175184 return "appmetadata-{}-{}" .format (environment or "" , client_id or "" )
@@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id):
178187 def _build_rt_key (
179188 cls ,
180189 home_account_id = None , environment = None , client_id = None , target = None ,
181- ** ignored ):
190+ ** ignored_payload_from_a_real_token ):
182191 return "-" .join ([
183192 home_account_id or "" ,
184193 environment or "" ,
@@ -189,16 +198,61 @@ def _build_rt_key(
189198 ]).lower ()
190199
191200 def remove_rt (self , rt_item ):
192- key = self ._build_rt_key (** rt_item )
193- with self ._lock :
194- self ._cache .setdefault (self .CredentialType .REFRESH_TOKEN , {}).pop (key , None )
201+ assert rt_item .get ("credential_type" ) == self .CredentialType .REFRESH_TOKEN
202+ return self .modify (self .CredentialType .REFRESH_TOKEN , rt_item )
195203
196204 def update_rt (self , rt_item , new_rt ):
197- key = self ._build_rt_key (** rt_item )
198- with self ._lock :
199- RTs = self ._cache .setdefault (self .CredentialType .REFRESH_TOKEN , {})
200- rt = RTs .get (key , {}) # key usually exists, but we'll survive its absence
201- rt ["secret" ] = new_rt
205+ assert rt_item .get ("credential_type" ) == self .CredentialType .REFRESH_TOKEN
206+ return self .modify (
207+ self .CredentialType .REFRESH_TOKEN , rt_item , {"secret" : new_rt })
208+
209+ @classmethod
210+ def _build_at_key (cls ,
211+ home_account_id = None , environment = None , client_id = None ,
212+ realm = None , target = None , ** ignored_payload_from_a_real_token ):
213+ return "-" .join ([
214+ home_account_id or "" ,
215+ environment or "" ,
216+ cls .CredentialType .ACCESS_TOKEN ,
217+ client_id ,
218+ realm or "" ,
219+ target or "" ,
220+ ]).lower ()
221+
222+ def remove_at (self , at_item ):
223+ assert at_item .get ("credential_type" ) == self .CredentialType .ACCESS_TOKEN
224+ return self .modify (self .CredentialType .ACCESS_TOKEN , at_item )
225+
226+ @classmethod
227+ def _build_idt_key (cls ,
228+ home_account_id = None , environment = None , client_id = None , realm = None ,
229+ ** ignored_payload_from_a_real_token ):
230+ return "-" .join ([
231+ home_account_id or "" ,
232+ environment or "" ,
233+ cls .CredentialType .ID_TOKEN ,
234+ client_id or "" ,
235+ realm or "" ,
236+ "" # Albeit irrelevant, schema requires an empty scope here
237+ ]).lower ()
238+
239+ def remove_idt (self , idt_item ):
240+ assert idt_item .get ("credential_type" ) == self .CredentialType .ID_TOKEN
241+ return self .modify (self .CredentialType .ID_TOKEN , idt_item )
242+
243+ @classmethod
244+ def _build_account_key (cls ,
245+ home_account_id = None , environment = None , realm = None ,
246+ ** ignored_payload_from_a_real_entry ):
247+ return "-" .join ([
248+ home_account_id or "" ,
249+ environment or "" ,
250+ realm or "" ,
251+ ]).lower ()
252+
253+ def remove_account (self , account_item ):
254+ assert "authority_type" in account_item
255+ return self .modify (self .CredentialType .ACCOUNT , account_item )
202256
203257
204258class SerializableTokenCache (TokenCache ):
@@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache):
221275 ...
222276
223277 :var bool has_state_changed:
224- Indicates whether the cache state has changed since last
278+ Indicates whether the cache state in the memory has changed since last
225279 :func:`~serialize` or :func:`~deserialize` call.
226280 """
227281 has_state_changed = False
@@ -230,12 +284,9 @@ def add(self, event, **kwargs):
230284 super (SerializableTokenCache , self ).add (event , ** kwargs )
231285 self .has_state_changed = True
232286
233- def remove_rt (self , rt_item ):
234- super (SerializableTokenCache , self ).remove_rt (rt_item )
235- self .has_state_changed = True
236-
237- def update_rt (self , rt_item , new_rt ):
238- super (SerializableTokenCache , self ).update_rt (rt_item , new_rt )
287+ def modify (self , credential_type , old_entry , new_key_value_pairs = None ):
288+ super (SerializableTokenCache , self ).modify (
289+ credential_type , old_entry , new_key_value_pairs )
239290 self .has_state_changed = True
240291
241292 def deserialize (self , state ):
0 commit comments