diff --git a/msal/application.py b/msal/application.py index 25a0db2b..4310a1f6 100644 --- a/msal/application.py +++ b/msal/application.py @@ -499,6 +499,7 @@ def __init__( except ( FileNotFoundError, # Or IOError in Python 2 pickle.UnpicklingError, # A corrupted http cache file + AttributeError, # Cache created by a different version of MSAL ): persisted_http_cache = {} # Recover by starting afresh atexit.register(lambda: pickle.dump( diff --git a/msal/individual_cache.py b/msal/individual_cache.py index 4c6fa00e..34f275dd 100644 --- a/msal/individual_cache.py +++ b/msal/individual_cache.py @@ -59,6 +59,10 @@ def __init__(self, mapping=None, capacity=None, expires_in=None, lock=None, self._expires_in = expires_in self._lock = Lock() if lock is None else lock + def _peek(self): + # Returns (sequence, timestamps) without triggering maintenance + return self._mapping.get(self._INDEX, ([], {})) + def _validate_key(self, key): if key == self._INDEX: raise ValueError("key {} is a reserved keyword in {}".format( @@ -85,7 +89,7 @@ def _set(self, key, value, expires_in): # This internal implementation powers both set() and __setitem__(), # so that they don't depend on each other. self._validate_key(key) - sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + sequence, timestamps = self._peek() self._maintenance(sequence, timestamps) # O(logN) now = int(time.time()) expires_at = now + expires_in @@ -136,7 +140,7 @@ def __getitem__(self, key): # O(1) self._validate_key(key) with self._lock: # Skip self._maintenance(), because it would need O(logN) time - sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + sequence, timestamps = self._peek() expires_at, created_at = timestamps[key] # Would raise KeyError accordingly now = int(time.time()) if not created_at <= now < expires_at: @@ -155,14 +159,14 @@ def __delitem__(self, key): # O(1) with self._lock: # Skip self._maintenance(), because it would need O(logN) time self._mapping.pop(key, None) # O(1) - sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + sequence, timestamps = self._peek() del timestamps[key] # O(1) self._mapping[self._INDEX] = sequence, timestamps def __len__(self): # O(logN) """Drop all expired items and return the remaining length""" with self._lock: - sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + sequence, timestamps = self._peek() self._maintenance(sequence, timestamps) # O(logN) self._mapping[self._INDEX] = sequence, timestamps return len(timestamps) # Faster than iter(self._mapping) when it is on disk @@ -170,7 +174,7 @@ def __len__(self): # O(logN) def __iter__(self): """Drop all expired items and return an iterator of the remaining items""" with self._lock: - sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + sequence, timestamps = self._peek() self._maintenance(sequence, timestamps) # O(logN) self._mapping[self._INDEX] = sequence, timestamps return iter(timestamps) # Faster than iter(self._mapping) when it is on disk diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 6f85571d..692bb7ad 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -112,8 +112,8 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None): class _ThrottledHttpClient(ThrottledHttpClientBase): - def __init__(self, http_client, **kwargs): - super(_ThrottledHttpClient, self).__init__(http_client, **kwargs) + def __init__(self, *args, **kwargs): + super(_ThrottledHttpClient, self).__init__(*args, **kwargs) self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format( @@ -124,7 +124,7 @@ def __init__(self, http_client, **kwargs): str(kwargs.get("params")) + str(kwargs.get("data"))), ), expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA - )(http_client.get) + )(self.get) # Note: Decorate the parent get(), not the http_client.get() class ManagedIdentityClient(object): @@ -233,8 +233,7 @@ def __init__( # (especially for 410 which was supposed to be a permanent failure). # 2. MI on Service Fabric specifically suggests to not retry on 404. # ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling ) - http_client.http_client # Patch the raw (unpatched) http client - if isinstance(http_client, ThrottledHttpClientBase) else http_client, + http_client, http_cache=http_cache, ) self._token_cache = token_cache or TokenCache() diff --git a/msal/sku.py b/msal/sku.py index 2a3172aa..1f38aeb7 100644 --- a/msal/sku.py +++ b/msal/sku.py @@ -2,5 +2,5 @@ """ # The __init__.py will import this. Not the other way around. -__version__ = "1.32.0" +__version__ = "1.32.1" SKU = "MSAL.Python" diff --git a/msal/throttled_http_client.py b/msal/throttled_http_client.py index ebad76c7..fd0c9ad5 100644 --- a/msal/throttled_http_client.py +++ b/msal/throttled_http_client.py @@ -3,6 +3,8 @@ from .individual_cache import _IndividualCache as IndividualCache from .individual_cache import _ExpiringMapping as ExpiringMapping +from .oauth2cli.http import Response +from .exceptions import MsalServiceError # https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 @@ -10,6 +12,7 @@ class RetryAfterParser(object): + FIELD_NAME_LOWER = "Retry-After".lower() def __init__(self, default_value=None): self._default_value = 5 if default_value is None else default_value @@ -20,9 +23,9 @@ def parse(self, *, result, **ignored): # Historically, MSAL's HttpResponse does not always have headers response, "headers", {}).items()} if not (response.status_code == 429 or response.status_code >= 500 - or "retry-after" in lowercase_headers): + or self.FIELD_NAME_LOWER in lowercase_headers): return 0 # Quick exit - retry_after = lowercase_headers.get("retry-after", self._default_value) + retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value) try: # AAD's retry_after uses integer format only # https://stackoverflow.microsoft.com/questions/264931/264932 @@ -37,16 +40,41 @@ def _extract_data(kwargs, key, default=None): return data.get(key) if isinstance(data, dict) else default +class NormalizedResponse(Response): + """A http response with the shape defined in Response, + but contains only the data we will store in cache. + """ + def __init__(self, raw_response): + super().__init__() + self.status_code = raw_response.status_code + self.text = raw_response.text + self.headers = { # Only keep the headers which ThrottledHttpClient cares about + k: v for k, v in raw_response.headers.items() + if k.lower() == RetryAfterParser.FIELD_NAME_LOWER + } + + ## Note: Don't use the following line, + ## because when being pickled, it will indirectly pickle the whole raw_response + # self.raise_for_status = raw_response.raise_for_status + def raise_for_status(self): + if self.status_code >= 400: + raise MsalServiceError("HTTP Error: {}".format(self.status_code)) + + class ThrottledHttpClientBase(object): """Throttle the given http_client by storing and retrieving data from cache. - This wrapper exists so that our patching post() and get() would prevent - re-patching side effect when/if same http_client being reused. + This base exists so that: + 1. These base post() and get() will return a NormalizedResponse + 2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient. - The subclass should implement post() and/or get() + Subclasses shall only need to dynamically decorate their post() and get() methods + in their __init__() method. """ def __init__(self, http_client, *, http_cache=None): - self.http_client = http_client + self.http_client = http_client.http_client if isinstance( + # If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client + http_client, ThrottledHttpClientBase) else http_client self._expiring_mapping = ExpiringMapping( # It will automatically clean up mapping=http_cache if http_cache is not None else {}, capacity=1024, # To prevent cache blowing up especially for CCA @@ -54,10 +82,10 @@ def __init__(self, http_client, *, http_cache=None): ) def post(self, *args, **kwargs): - return self.http_client.post(*args, **kwargs) + return NormalizedResponse(self.http_client.post(*args, **kwargs)) def get(self, *args, **kwargs): - return self.http_client.get(*args, **kwargs) + return NormalizedResponse(self.http_client.get(*args, **kwargs)) def close(self): return self.http_client.close() @@ -68,12 +96,11 @@ def _hash(raw): class ThrottledHttpClient(ThrottledHttpClientBase): - def __init__(self, http_client, *, default_throttle_time=None, **kwargs): - super(ThrottledHttpClient, self).__init__(http_client, **kwargs) - - _post = http_client.post # We'll patch _post, and keep original post() intact - - _post = IndividualCache( + """A throttled http client that is used by MSAL's non-managed identity clients.""" + def __init__(self, *args, default_throttle_time=None, **kwargs): + """Decorate self.post() and self.get() dynamically""" + super(ThrottledHttpClient, self).__init__(*args, **kwargs) + self.post = IndividualCache( # Internal specs requires throttling on at least token endpoint, # here we have a generic patch for POST on all endpoints. mapping=self._expiring_mapping, @@ -91,9 +118,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs): _extract_data(kwargs, "username")))), # "account" of ROPC ), expires_in=RetryAfterParser(default_throttle_time or 5).parse, - )(_post) + )(self.post) - _post = IndividualCache( # It covers the "UI required cache" + self.post = IndividualCache( # It covers the "UI required cache" mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format( args[0], # It is the url, typically containing authority and tenant @@ -125,12 +152,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs): isinstance(kwargs.get("data"), dict) and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT ) - and "retry-after" not in set( # Leave it to the Retry-After decorator + and RetryAfterParser.FIELD_NAME_LOWER not in set( # Otherwise leave it to the Retry-After decorator h.lower() for h in getattr(result, "headers", {}).keys()) else 0, - )(_post) - - self.post = _post + )(self.post) self.get = IndividualCache( # Typically those discovery GETs mapping=self._expiring_mapping, @@ -140,9 +165,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs): ), expires_in=lambda result=None, **ignored: 3600*24 if 200 <= result.status_code < 300 else 0, - )(http_client.get) - - # The following 2 methods have been defined dynamically by __init__() - #def post(self, *args, **kwargs): pass - #def get(self, *args, **kwargs): pass - + )(self.get) diff --git a/tests/test_individual_cache.py b/tests/test_individual_cache.py index 38bd572d..ce4aa993 100644 --- a/tests/test_individual_cache.py +++ b/tests/test_individual_cache.py @@ -8,7 +8,13 @@ class TestExpiringMapping(unittest.TestCase): def setUp(self): self.mapping = {} - self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1) + self.expires_in = 1 + self.m = ExpiringMapping( + mapping=self.mapping, capacity=2, expires_in=self.expires_in) + + def how_many(self): + # This helper checks how many items are in the mapping, WITHOUT triggering purge + return len(self.m._peek()[1]) def test_should_disallow_accessing_reserved_keyword(self): with self.assertRaises(ValueError): @@ -40,11 +46,21 @@ def test_iter_should_purge(self): sleep(1) self.assertEqual([], list(self.m)) - def test_get_should_purge(self): + def test_get_should_not_purge_and_should_return_only_when_the_item_is_still_valid(self): self.m["thing one"] = "one" + self.m["thing two"] = "two" sleep(1) + self.assertEqual(2, self.how_many(), "We begin with 2 items") with self.assertRaises(KeyError): self.m["thing one"] + self.assertEqual(1, self.how_many(), "get() should not purge the remaining items") + + def test_setitem_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + self.m["thing two"] = "two" + self.assertEqual(1, self.how_many(), "setitem() should purge all expired items") + self.assertEqual("two", self.m["thing two"], "The remaining item should be thing two") def test_various_expiring_time(self): self.assertEqual(0, len(self.m)) @@ -57,12 +73,13 @@ def test_various_expiring_time(self): def test_old_item_can_be_updated_with_new_expiry_time(self): self.assertEqual(0, len(self.m)) self.m["thing"] = "one" - self.m.set("thing", "two", 2) + new_lifetime = 3 # 2-second seems too short and causes flakiness + self.m.set("thing", "two", new_lifetime) self.assertEqual(1, len(self.m), "It contains 1 item") self.assertEqual("two", self.m["thing"], 'Already been updated to "two"') - sleep(1) + sleep(self.expires_in) self.assertEqual("two", self.m["thing"], "Not yet expires") - sleep(1) + sleep(new_lifetime - self.expires_in) self.assertEqual(0, len(self.m)) def test_oversized_input_should_purge_most_aging_item(self): diff --git a/tests/test_mi.py b/tests/test_mi.py index a7c2cb6c..e065899c 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -9,7 +9,8 @@ from mock import patch, ANY, mock_open, Mock import requests -from tests.http_client import MinimalResponse +from tests.test_throttled_http_client import ( + MinimalResponse, ThrottledHttpClientBaseTestCase, DummyHttpClient) from msal import ( SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient, @@ -17,6 +18,7 @@ ArcPlatformNotSupportedError, ) from msal.managed_identity import ( + _ThrottledHttpClient, _supported_arc_platforms_and_their_prefixes, get_managed_identity_source, APP_SERVICE, @@ -49,6 +51,37 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) +class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase): + def test_throttled_http_client_should_not_alter_original_http_client(self): + self.assertNotAlteringOriginalHttpClient(_ThrottledHttpClient) + + def test_throttled_http_client_should_not_cache_successful_http_response(self): + http_cache = {} + http_client=DummyHttpClient( + status_code=200, + response_text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + ) + app = ManagedIdentityClient( + SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache) + result = app.acquire_token_for_client(resource="R") + self.assertEqual("AT", result["access_token"]) + self.assertEqual({}, http_cache, "Should not cache successful http response") + + def test_throttled_http_client_should_cache_unsuccessful_http_response(self): + http_cache = {} + http_client=DummyHttpClient( + status_code=400, + response_headers={"Retry-After": "1"}, + response_text='{"error": "invalid_request"}', + ) + app = ManagedIdentityClient( + SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache) + result = app.acquire_token_for_client(resource="R") + self.assertEqual("invalid_request", result["error"]) + self.assertNotEqual({}, http_cache, "Should cache unsuccessful http response") + self.assertCleanPickle(http_cache) + + class ClientTestCase(unittest.TestCase): maxDiff = None diff --git a/tests/test_throttled_http_client.py b/tests/test_throttled_http_client.py index 3994719d..c15a7877 100644 --- a/tests/test_throttled_http_client.py +++ b/tests/test_throttled_http_client.py @@ -1,27 +1,43 @@ # Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases +import pickle from time import sleep from random import random import logging -from msal.throttled_http_client import ThrottledHttpClient + +from msal.throttled_http_client import ( + ThrottledHttpClientBase, ThrottledHttpClient, NormalizedResponse) + from tests import unittest -from tests.http_client import MinimalResponse +from tests.http_client import MinimalResponse as _MinimalResponse logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) +class MinimalResponse(_MinimalResponse): + SIGNATURE = str(random()).encode("utf-8") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ = ( # Only an instance attribute will be stored in pickled instance + self.__class__.SIGNATURE) # Useful for testing its presence in pickled instance + + class DummyHttpClient(object): - def __init__(self, status_code=None, response_headers=None): + def __init__(self, status_code=None, response_headers=None, response_text=None): self._status_code = status_code self._response_headers = response_headers + self._response_text = response_text def _build_dummy_response(self): return MinimalResponse( status_code=self._status_code, headers=self._response_headers, - text=random(), # So that we'd know whether a new response is received - ) + text=self._response_text if self._response_text is not None else str( + random() # So that we'd know whether a new response is received + ), + ) def post(self, url, params=None, data=None, headers=None, **kwargs): return self._build_dummy_response() @@ -37,19 +53,54 @@ class CloseMethodCalled(Exception): pass -class TestHttpDecoration(unittest.TestCase): +class ThrottledHttpClientBaseTestCase(unittest.TestCase): - def test_throttled_http_client_should_not_alter_original_http_client(self): + def assertCleanPickle(self, obj): + self.assertTrue(bool(obj), "The object should not be empty") + self.assertNotIn( + MinimalResponse.SIGNATURE, pickle.dumps(obj), + "A pickled object should not contain undesirable data") + + def assertValidResponse(self, response): + self.assertIsInstance(response, NormalizedResponse) + self.assertCleanPickle(response) + + def test_pickled_minimal_response_should_contain_signature(self): + self.assertIn(MinimalResponse.SIGNATURE, pickle.dumps(MinimalResponse( + status_code=200, headers={}, text="foo"))) + + def test_throttled_http_client_base_response_should_not_contain_signature(self): + http_client = ThrottledHttpClientBase(DummyHttpClient(status_code=200)) + response = http_client.post("https://example.com") + self.assertValidResponse(response) + + def assertNotAlteringOriginalHttpClient(self, ThrottledHttpClientClass): original_http_client = DummyHttpClient() original_get = original_http_client.get original_post = original_http_client.post - throttled_http_client = ThrottledHttpClient(original_http_client) + throttled_http_client = ThrottledHttpClientClass(original_http_client) goal = """The implementation should wrap original http_client and keep it intact, instead of monkey-patching it""" self.assertNotEqual(throttled_http_client, original_http_client, goal) self.assertEqual(original_post, original_http_client.post) self.assertEqual(original_get, original_http_client.get) + def test_throttled_http_client_base_should_not_alter_original_http_client(self): + self.assertNotAlteringOriginalHttpClient(ThrottledHttpClientBase) + + def test_throttled_http_client_base_should_not_nest_http_client(self): + original_http_client = DummyHttpClient() + throttled_http_client = ThrottledHttpClientBase(original_http_client) + self.assertIs(original_http_client, throttled_http_client.http_client) + nested_throttled_http_client = ThrottledHttpClientBase(throttled_http_client) + self.assertIs(original_http_client, nested_throttled_http_client.http_client) + + +class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase): + + def test_throttled_http_client_should_not_alter_original_http_client(self): + self.assertNotAlteringOriginalHttpClient(ThrottledHttpClient) + def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( self, http_client, retry_after): http_cache = {} @@ -112,15 +163,23 @@ def test_one_invalid_grant_should_block_a_similar_request(self): http_client = DummyHttpClient( status_code=400) # It covers invalid_grant and interaction_required http_client = ThrottledHttpClient(http_client, http_cache=http_cache) + resp1 = http_client.post("https://example.com", data={"claims": "foo"}) logger.debug(http_cache) + self.assertValidResponse(resp1) resp1_again = http_client.post("https://example.com", data={"claims": "foo"}) + self.assertValidResponse(resp1_again) self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response") + resp2 = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertValidResponse(resp2) self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") resp2_again = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertValidResponse(resp2_again) self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response") + self.assertCleanPickle(http_cache) + def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self): """ Need not test multiple FOCI app's acquire_token_silent() here. By design, diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..dffd5110 --- /dev/null +++ b/tox.ini @@ -0,0 +1,28 @@ +[tox] +env_list = + py3 +minversion = 4.21.2 + +[testenv] +description = run the tests with pytest +package = wheel +wheel_build_env = .pkg +passenv = + # This allows tox environment on a DevBox to trigger host browser + DISPLAY +deps = + pytest>=6 + -r requirements.txt +commands = + pip list + {posargs:pytest --color=yes} + +[testenv:azcli] +deps = + azure-cli +commands_pre = + # It will unfortunately be run every time but luckily subsequent runs are fast. + pip install -e . +commands = + pip list + {posargs:az --version}