Skip to content

Commit 37d9af5

Browse files
authored
Merge pull request #816 from AzureAD/release-1.32.3
Merge release 1.32.3 back to dev branch
2 parents 1321e37 + dd4fe69 commit 37d9af5

File tree

9 files changed: 242 additions and 55 deletions.

msal/application.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -506,6 +506,7 @@ def __init__(
506506
except (
507507
FileNotFoundError, # Or IOError in Python 2
508508
pickle.UnpicklingError, # A corrupted http cache file
509+
AttributeError, # Cache created by a different version of MSAL
509510
):
510511
persisted_http_cache = {} # Recover by starting afresh
511512
atexit.register(lambda: pickle.dump(

msal/individual_cache.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,10 @@ def __init__(self, mapping=None, capacity=None, expires_in=None, lock=None,
5959
self._expires_in = expires_in
6060
self._lock = Lock() if lock is None else lock
6161

62+
def _peek(self):
63+
# Returns (sequence, timestamps) without triggering maintenance
64+
return self._mapping.get(self._INDEX, ([], {}))
65+
6266
def _validate_key(self, key):
6367
if key == self._INDEX:
6468
raise ValueError("key {} is a reserved keyword in {}".format(
@@ -85,7 +89,7 @@ def _set(self, key, value, expires_in):
8589
# This internal implementation powers both set() and __setitem__(),
8690
# so that they don't depend on each other.
8791
self._validate_key(key)
88-
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
92+
sequence, timestamps = self._peek()
8993
self._maintenance(sequence, timestamps) # O(logN)
9094
now = int(time.time())
9195
expires_at = now + expires_in
@@ -136,7 +140,7 @@ def __getitem__(self, key): # O(1)
136140
self._validate_key(key)
137141
with self._lock:
138142
# Skip self._maintenance(), because it would need O(logN) time
139-
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
143+
sequence, timestamps = self._peek()
140144
expires_at, created_at = timestamps[key] # Would raise KeyError accordingly
141145
now = int(time.time())
142146
if not created_at <= now < expires_at:
@@ -155,22 +159,22 @@ def __delitem__(self, key): # O(1)
155159
with self._lock:
156160
# Skip self._maintenance(), because it would need O(logN) time
157161
self._mapping.pop(key, None) # O(1)
158-
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
162+
sequence, timestamps = self._peek()
159163
del timestamps[key] # O(1)
160164
self._mapping[self._INDEX] = sequence, timestamps
161165

162166
def __len__(self): # O(logN)
163167
"""Drop all expired items and return the remaining length"""
164168
with self._lock:
165-
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
169+
sequence, timestamps = self._peek()
166170
self._maintenance(sequence, timestamps) # O(logN)
167171
self._mapping[self._INDEX] = sequence, timestamps
168172
return len(timestamps) # Faster than iter(self._mapping) when it is on disk
169173

170174
def __iter__(self):
171175
"""Drop all expired items and return an iterator of the remaining items"""
172176
with self._lock:
173-
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
177+
sequence, timestamps = self._peek()
174178
self._maintenance(sequence, timestamps) # O(logN)
175179
self._mapping[self._INDEX] = sequence, timestamps
176180
return iter(timestamps) # Faster than iter(self._mapping) when it is on disk

msal/managed_identity.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -113,8 +113,8 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):
113113

114114

115115
class _ThrottledHttpClient(ThrottledHttpClientBase):
116-
def __init__(self, http_client, **kwargs):
117-
super(_ThrottledHttpClient, self).__init__(http_client, **kwargs)
116+
def __init__(self, *args, **kwargs):
117+
super(_ThrottledHttpClient, self).__init__(*args, **kwargs)
118118
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
119119
mapping=self._expiring_mapping,
120120
key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format(
@@ -125,7 +125,7 @@ def __init__(self, http_client, **kwargs):
125125
str(kwargs.get("params")) + str(kwargs.get("data"))),
126126
),
127127
expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA
128-
)(http_client.get)
128+
)(self.get) # Note: Decorate the parent get(), not the http_client.get()
129129

130130

131131
class ManagedIdentityClient(object):
@@ -246,8 +246,7 @@ def __init__(
246246
# (especially for 410 which was supposed to be a permanent failure).
247247
# 2. MI on Service Fabric specifically suggests to not retry on 404.
248248
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
249-
http_client.http_client # Patch the raw (unpatched) http client
250-
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
249+
http_client,
251250
http_cache=http_cache,
252251
)
253252
self._token_cache = token_cache or TokenCache()

msal/sku.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,5 +2,5 @@
22
"""
33

44
# The __init__.py will import this. Not the other way around.
5-
__version__ = "1.32.0"
5+
__version__ = "1.32.3"
66
SKU = "MSAL.Python"

msal/throttled_http_client.py

Lines changed: 58 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -3,26 +3,34 @@
33

44
from .individual_cache import _IndividualCache as IndividualCache
55
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
from .oauth2cli.http import Response
7+
from .exceptions import MsalServiceError
68

79

810
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
911
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1012

1113

14+
def _get_headers(response):
15+
# MSAL's HttpResponse did not have headers until 1.23.0
16+
# https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
17+
# This helper ensures graceful degradation to {} without exception
18+
return getattr(response, "headers", {})
19+
20+
1221
class RetryAfterParser(object):
22+
FIELD_NAME_LOWER = "Retry-After".lower()
1323
def __init__(self, default_value=None):
1424
self._default_value = 5 if default_value is None else default_value
1525

1626
def parse(self, *, result, **ignored):
1727
"""Return seconds to throttle"""
1828
response = result
19-
lowercase_headers = {k.lower(): v for k, v in getattr(
20-
# Historically, MSAL's HttpResponse does not always have headers
21-
response, "headers", {}).items()}
29+
lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
2230
if not (response.status_code == 429 or response.status_code >= 500
23-
or "retry-after" in lowercase_headers):
31+
or self.FIELD_NAME_LOWER in lowercase_headers):
2432
return 0 # Quick exit
25-
retry_after = lowercase_headers.get("retry-after", self._default_value)
33+
retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
2634
try:
2735
# AAD's retry_after uses integer format only
2836
# https://stackoverflow.microsoft.com/questions/264931/264932
@@ -37,27 +45,55 @@ def _extract_data(kwargs, key, default=None):
3745
return data.get(key) if isinstance(data, dict) else default
3846

3947

48+
class NormalizedResponse(Response):
49+
"""A http response with the shape defined in Response,
50+
but contains only the data we will store in cache.
51+
"""
52+
def __init__(self, raw_response):
53+
super().__init__()
54+
self.status_code = raw_response.status_code
55+
self.text = raw_response.text
56+
self.headers = {
57+
k.lower(): v for k, v in _get_headers(raw_response).items()
58+
# Attempted storing only a small set of headers (such as Retry-After),
59+
# but it tends to lead to missing information (such as WWW-Authenticate).
60+
# So we store all headers, which are expected to contain only public info,
61+
# because we throttle only error responses and public responses.
62+
}
63+
64+
## Note: Don't use the following line,
65+
## because when being pickled, it will indirectly pickle the whole raw_response
66+
# self.raise_for_status = raw_response.raise_for_status
67+
def raise_for_status(self):
68+
if self.status_code >= 400:
69+
raise MsalServiceError("HTTP Error: {}".format(self.status_code))
70+
71+
4072
class ThrottledHttpClientBase(object):
4173
"""Throttle the given http_client by storing and retrieving data from cache.
4274
43-
This wrapper exists so that our patching post() and get() would prevent
44-
re-patching side effect when/if same http_client being reused.
75+
This base exists so that:
76+
1. These base post() and get() will return a NormalizedResponse
77+
2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient.
4578
46-
The subclass should implement post() and/or get()
79+
Subclasses shall only need to dynamically decorate their post() and get() methods
80+
in their __init__() method.
4781
"""
4882
def __init__(self, http_client, *, http_cache=None):
49-
self.http_client = http_client
83+
self.http_client = http_client.http_client if isinstance(
84+
# If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
85+
http_client, ThrottledHttpClientBase) else http_client
5086
self._expiring_mapping = ExpiringMapping( # It will automatically clean up
5187
mapping=http_cache if http_cache is not None else {},
5288
capacity=1024, # To prevent cache blowing up especially for CCA
5389
lock=Lock(), # TODO: This should ideally also allow customization
5490
)
5591

5692
def post(self, *args, **kwargs):
57-
return self.http_client.post(*args, **kwargs)
93+
return NormalizedResponse(self.http_client.post(*args, **kwargs))
5894

5995
def get(self, *args, **kwargs):
60-
return self.http_client.get(*args, **kwargs)
96+
return NormalizedResponse(self.http_client.get(*args, **kwargs))
6197

6298
def close(self):
6399
return self.http_client.close()
@@ -68,12 +104,11 @@ def _hash(raw):
68104

69105

70106
class ThrottledHttpClient(ThrottledHttpClientBase):
71-
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
72-
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)
73-
74-
_post = http_client.post # We'll patch _post, and keep original post() intact
75-
76-
_post = IndividualCache(
107+
"""A throttled http client that is used by MSAL's non-managed identity clients."""
108+
def __init__(self, *args, default_throttle_time=None, **kwargs):
109+
"""Decorate self.post() and self.get() dynamically"""
110+
super(ThrottledHttpClient, self).__init__(*args, **kwargs)
111+
self.post = IndividualCache(
77112
# Internal specs requires throttling on at least token endpoint,
78113
# here we have a generic patch for POST on all endpoints.
79114
mapping=self._expiring_mapping,
@@ -91,9 +126,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91126
_extract_data(kwargs, "username")))), # "account" of ROPC
92127
),
93128
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
94-
)(_post)
129+
)(self.post)
95130

96-
_post = IndividualCache( # It covers the "UI required cache"
131+
self.post = IndividualCache( # It covers the "UI required cache"
97132
mapping=self._expiring_mapping,
98133
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
99134
args[0], # It is the url, typically containing authority and tenant
@@ -125,12 +160,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
125160
isinstance(kwargs.get("data"), dict)
126161
and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
127162
)
128-
and "retry-after" not in set( # Leave it to the Retry-After decorator
129-
h.lower() for h in getattr(result, "headers", {}).keys())
163+
and RetryAfterParser.FIELD_NAME_LOWER not in set( # Otherwise leave it to the Retry-After decorator
164+
h.lower() for h in _get_headers(result))
130165
else 0,
131-
)(_post)
132-
133-
self.post = _post
166+
)(self.post)
134167

135168
self.get = IndividualCache( # Typically those discovery GETs
136169
mapping=self._expiring_mapping,
@@ -140,9 +173,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140173
),
141174
expires_in=lambda result=None, **ignored:
142175
3600*24 if 200 <= result.status_code < 300 else 0,
143-
)(http_client.get)
144-
145-
# The following 2 methods have been defined dynamically by __init__()
146-
#def post(self, *args, **kwargs): pass
147-
#def get(self, *args, **kwargs): pass
148-
176+
)(self.get)

tests/test_individual_cache.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,13 @@
88
class TestExpiringMapping(unittest.TestCase):
99
def setUp(self):
1010
self.mapping = {}
11-
self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1)
11+
self.expires_in = 1
12+
self.m = ExpiringMapping(
13+
mapping=self.mapping, capacity=2, expires_in=self.expires_in)
14+
15+
def how_many(self):
16+
# This helper checks how many items are in the mapping, WITHOUT triggering purge
17+
return len(self.m._peek()[1])
1218

1319
def test_should_disallow_accessing_reserved_keyword(self):
1420
with self.assertRaises(ValueError):
@@ -40,11 +46,21 @@ def test_iter_should_purge(self):
4046
sleep(1)
4147
self.assertEqual([], list(self.m))
4248

43-
def test_get_should_purge(self):
49+
def test_get_should_not_purge_and_should_return_only_when_the_item_is_still_valid(self):
4450
self.m["thing one"] = "one"
51+
self.m["thing two"] = "two"
4552
sleep(1)
53+
self.assertEqual(2, self.how_many(), "We begin with 2 items")
4654
with self.assertRaises(KeyError):
4755
self.m["thing one"]
56+
self.assertEqual(1, self.how_many(), "get() should not purge the remaining items")
57+
58+
def test_setitem_should_purge(self):
59+
self.m["thing one"] = "one"
60+
sleep(1)
61+
self.m["thing two"] = "two"
62+
self.assertEqual(1, self.how_many(), "setitem() should purge all expired items")
63+
self.assertEqual("two", self.m["thing two"], "The remaining item should be thing two")
4864

4965
def test_various_expiring_time(self):
5066
self.assertEqual(0, len(self.m))
@@ -57,12 +73,13 @@ def test_various_expiring_time(self):
5773
def test_old_item_can_be_updated_with_new_expiry_time(self):
5874
self.assertEqual(0, len(self.m))
5975
self.m["thing"] = "one"
60-
self.m.set("thing", "two", 2)
76+
new_lifetime = 3 # 2-second seems too short and causes flakiness
77+
self.m.set("thing", "two", new_lifetime)
6178
self.assertEqual(1, len(self.m), "It contains 1 item")
6279
self.assertEqual("two", self.m["thing"], 'Already been updated to "two"')
63-
sleep(1)
80+
sleep(self.expires_in)
6481
self.assertEqual("two", self.m["thing"], "Not yet expires")
65-
sleep(1)
82+
sleep(new_lifetime - self.expires_in)
6683
self.assertEqual(0, len(self.m))
6784

6885
def test_oversized_input_should_purge_most_aging_item(self):

tests/test_mi.py

Lines changed: 34 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,16 @@
1111
from mock import patch, ANY, mock_open, Mock
1212
import requests
1313

14-
from tests.http_client import MinimalResponse
14+
from tests.test_throttled_http_client import (
15+
MinimalResponse, ThrottledHttpClientBaseTestCase, DummyHttpClient)
1516
from msal import (
1617
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
1718
ManagedIdentityClient,
1819
ManagedIdentityError,
1920
ArcPlatformNotSupportedError,
2021
)
2122
from msal.managed_identity import (
23+
_ThrottledHttpClient,
2224
_supported_arc_platforms_and_their_prefixes,
2325
get_managed_identity_source,
2426
APP_SERVICE,
@@ -51,6 +53,37 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
5153
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})
5254

5355

56+
class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
57+
def test_throttled_http_client_should_not_alter_original_http_client(self):
58+
self.assertNotAlteringOriginalHttpClient(_ThrottledHttpClient)
59+
60+
def test_throttled_http_client_should_not_cache_successful_http_response(self):
61+
http_cache = {}
62+
http_client=DummyHttpClient(
63+
status_code=200,
64+
response_text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
65+
)
66+
app = ManagedIdentityClient(
67+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
68+
result = app.acquire_token_for_client(resource="R")
69+
self.assertEqual("AT", result["access_token"])
70+
self.assertEqual({}, http_cache, "Should not cache successful http response")
71+
72+
def test_throttled_http_client_should_cache_unsuccessful_http_response(self):
73+
http_cache = {}
74+
http_client=DummyHttpClient(
75+
status_code=400,
76+
response_headers={"Retry-After": "1"},
77+
response_text='{"error": "invalid_request"}',
78+
)
79+
app = ManagedIdentityClient(
80+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
81+
result = app.acquire_token_for_client(resource="R")
82+
self.assertEqual("invalid_request", result["error"])
83+
self.assertNotEqual({}, http_cache, "Should cache unsuccessful http response")
84+
self.assertCleanPickle(http_cache)
85+
86+
5487
class ClientTestCase(unittest.TestCase):
5588
maxDiff = None
5689

0 commit comments

Comments (0)