Skip to content

Commit 6ee2837

Browse files
committed
Refactor MSAL HTTP cache to use JSON format and implement JsonCache class
1 parent 104cb2e commit 6ee2837

File tree

2 files changed

+115
-3
lines changed

2 files changed

+115
-3
lines changed

src/azure-cli-core/azure/cli/core/auth/identity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False, use
8282
config_dir = get_config_dir()
8383
self._token_cache_file = os.path.join(config_dir, "msal_token_cache")
8484
self._secret_file = os.path.join(config_dir, "service_principal_entries")
85-
self._msal_http_cache_file = os.path.join(config_dir, "msal_http_cache.bin")
85+
self._msal_http_cache_file = os.path.join(config_dir, "msal_http_cache.json")
8686

8787
# We make _msal_app_instance an instance attribute, instead of a class attribute,
8888
# because MSAL apps can have different tenant IDs.
@@ -131,8 +131,8 @@ def _load_msal_token_cache(self):
131131
return cache
132132

133133
def _load_msal_http_cache(self):
134-
from .binary_cache import BinaryCache
135-
http_cache = BinaryCache(self._msal_http_cache_file)
134+
from .json_cache import JsonCache
135+
http_cache = JsonCache(self._msal_http_cache_file)
136136
return http_cache
137137

138138
@property
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
import json
6+
from collections.abc import MutableMapping
7+
8+
from azure.cli.core.decorators import retry
9+
from knack.log import get_logger
10+
from msal.throttled_http_client import NormalizedResponse
11+
12+
logger = get_logger(__name__)
13+
14+
15+
class NormalizedResponseJsonEncoder(json.JSONEncoder):
16+
def default(self, obj):
17+
if isinstance(obj, NormalizedResponse):
18+
return {
19+
"status_code": obj.status_code,
20+
"text": obj.text,
21+
"headers": obj.headers,
22+
}
23+
return super().default(obj)
24+
25+
26+
class JsonCache(MutableMapping):
27+
"""
28+
A simple dict-like class that is backed by a json file. This is designed for the MSAL HTTP cache.
29+
30+
All direct modifications with `__setitem__` and `__delitem__` will save the file.
31+
Indirect modifications should be followed by a call to `save`.
32+
"""
33+
def __init__(self, file_name):
34+
super().__init__()
35+
self.filename = file_name
36+
self.data = {}
37+
self.load()
38+
39+
@retry()
40+
def _load(self):
41+
"""Load cache with retry. If it still fails at last, raise the original exception as-is."""
42+
try:
43+
with open(self.filename, 'r', encoding='utf-8') as f:
44+
data = json.load(f)
45+
response_keys = [key for key in data if key != "_index_"]
46+
for key in response_keys:
47+
try:
48+
response_dict = data[key]
49+
# Reconstruct NormalizedResponse from the stored dict
50+
response = NormalizedResponse.__new__(NormalizedResponse)
51+
response.status_code = response_dict["status_code"]
52+
response.text = response_dict["text"]
53+
response.headers = response_dict["headers"]
54+
data[key] = response
55+
except KeyError as e:
56+
logger.debug("Failed to reconstruct NormalizedResponse for key %s: %s", key, e)
57+
# If reconstruction fails, remove the entry from cache
58+
del data[key]
59+
return data
60+
except FileNotFoundError:
61+
# The cache file has not been created. This is expected. No need to retry.
62+
logger.debug("%s not found. Using a fresh one.", self.filename)
63+
return {}
64+
65+
def load(self):
66+
logger.debug("load: %s", self.filename)
67+
try:
68+
self.data = self._load()
69+
except Exception as ex: # pylint: disable=broad-exception-caught
70+
# If we still get exception after retry, ignore all types of exceptions and use a new cache.
71+
# - EOFError is caused by empty cache file created by other az instance, but hasn't been filled yet.
72+
# - KeyError is caused by reading cache generated by different MSAL versions.
73+
logger.debug("Failed to load cache: %s. Using a fresh one.", ex)
74+
self.data = {} # Ignore a non-existing or corrupted http_cache
75+
76+
@retry()
77+
def _save(self):
78+
with open(self.filename, 'w', encoding='utf-8') as f:
79+
# At this point, an empty cache file will be created. Loading this cache file will
80+
# raise EOFError. This can be simulated by adding time.sleep(30) here.
81+
# So during loading, EOFError is ignored.
82+
json.dump(self.data, f, cls=NormalizedResponseJsonEncoder)
83+
84+
def save(self):
85+
logger.debug("save: %s", self.filename)
86+
# If 2 processes write at the same time, the cache will be corrupted,
87+
# but that is fine. Subsequent runs would reach eventual consistency.
88+
try:
89+
self._save()
90+
except TypeError as e:
91+
# If serialization fails, skip saving to avoid corrupting the cache file
92+
logger.debug("Failed to save cache due to TypeError: %s", e)
93+
94+
def get(self, key, default=None):
95+
return self.data.get(key, default)
96+
97+
def __getitem__(self, key):
98+
return self.data[key]
99+
100+
def __setitem__(self, key, value):
101+
self.data[key] = value
102+
self.save()
103+
104+
def __delitem__(self, key):
105+
del self.data[key]
106+
self.save()
107+
108+
def __iter__(self):
109+
return iter(self.data)
110+
111+
def __len__(self):
112+
return len(self.data)

0 commit comments

Comments
 (0)