|
1 | 1 | import json |
2 | 2 | import threading |
3 | 3 | import time |
| 4 | +from typing import Optional # Needed in Python 3.7 & 3.8 |
4 | 5 | import logging |
5 | 6 | import warnings |
6 | 7 |
|
@@ -39,6 +40,25 @@ class AuthorityType: |
39 | 40 | ADFS = "ADFS" |
40 | 41 | MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA |
41 | 42 |
|
| 43 | + _data_to_at: dict[str, str] = { # field_in_data: field_in_cache |
| 44 | + # Store extra data which we explicitly allow, |
| 45 | + # so that we won't accidentally store a user's password etc. |
| 46 | + # It can be used to store for example key_id used in SSH-cert or POP |
| 47 | + } |
| 48 | + _response_to_at: dict[str, str] = { # field_in_response: field_in_cache |
| 49 | + } |
| 50 | + |
| 51 | + def _set( |
| 52 | + self, |
| 53 | + *, |
| 54 | + data_to_at: Optional[dict[str, str]] = None, |
| 55 | + response_to_at: Optional[dict[str, str]] = None, |
| 56 | + ) -> None: |
| 57 | + # This helper should probably be better in __init__(), |
| 58 | + # but there is no easy way for MSAL EX to pick up a kwargs |
| 59 | + self._data_to_at = data_to_at or {} |
| 60 | + self._response_to_at = response_to_at or {} |
| 61 | + |
42 | 62 | def __init__(self): |
43 | 63 | self._lock = threading.RLock() |
44 | 64 | self._cache = {} |
@@ -267,11 +287,14 @@ def __add(self, event, now=None): |
267 | 287 | "expires_on": str(now + expires_in), # Same here |
268 | 288 | "extended_expires_on": str(now + ext_expires_in) # Same here |
269 | 289 | } |
270 | | - at.update({k: data[k] for k in data if k in { |
271 | | - # Also store extra data which we explicitly allow |
272 | | - # So that we won't accidentally store a user's password etc. |
273 | | - "key_id", # It happens in SSH-cert or POP scenario |
274 | | - }}) |
| 290 | + for field_in_resp, field_in_cache in self._response_to_at.items(): |
| 291 | + value = response.get(field_in_resp) |
| 292 | + if value: |
| 293 | + at[field_in_cache] = value |
| 294 | + for field_in_data, field_in_cache in self._data_to_at.items(): |
| 295 | + value = data.get(field_in_data) |
| 296 | + if value: |
| 297 | + at[field_in_cache] = value |
275 | 298 | if "refresh_in" in response: |
276 | 299 | refresh_in = response["refresh_in"] # It is an integer |
277 | 300 | at["refresh_on"] = str(now + refresh_in) # Schema wants a string |
|
0 commit comments