Skip to content

Commit 92d954d

Browse files
committed
Only cache desirable data in http cache
1 parent b92b4f1 commit 92d954d

File tree

4 files changed

+124
-31
lines changed

4 files changed

+124
-31
lines changed

msal/managed_identity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):
112112

113113

114114
class _ThrottledHttpClient(ThrottledHttpClientBase):
115-
def __init__(self, http_client, **kwargs):
116-
super(_ThrottledHttpClient, self).__init__(http_client, **kwargs)
115+
def __init__(self, *args, **kwargs):
116+
super(_ThrottledHttpClient, self).__init__(*args, **kwargs)
117117
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
118118
mapping=self._expiring_mapping,
119119
key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format(
@@ -124,7 +124,7 @@ def __init__(self, http_client, **kwargs):
124124
str(kwargs.get("params")) + str(kwargs.get("data"))),
125125
),
126126
expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA
127-
)(http_client.get)
127+
)(self.get)
128128

129129

130130
class ManagedIdentityClient(object):

msal/throttled_http_client.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from .individual_cache import _IndividualCache as IndividualCache
55
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
from .oauth2cli.http import Response
7+
from .exceptions import MsalServiceError
68

79

810
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
@@ -37,6 +39,24 @@ def _extract_data(kwargs, key, default=None):
3739
return data.get(key) if isinstance(data, dict) else default
3840

3941

42+
class NormalizedResponse(Response):
43+
"""A http response with the shape defined in Response,
44+
but contains only the data we will store in cache.
45+
"""
46+
def __init__(self, raw_response):
47+
super().__init__()
48+
self.status_code = raw_response.status_code
49+
self.text = raw_response.text
50+
self.headers = raw_response.headers
51+
52+
## Note: Don't use the following line,
53+
## because when being pickled, it will indirectly pickle the whole raw_response
54+
# self.raise_for_status = raw_response.raise_for_status
55+
def raise_for_status(self):
56+
if self.status_code >= 400:
57+
raise MsalServiceError("HTTP Error: {}".format(self.status_code))
58+
59+
4060
class ThrottledHttpClientBase(object):
4161
"""Throttle the given http_client by storing and retrieving data from cache.
4262
@@ -54,10 +74,10 @@ def __init__(self, http_client, *, http_cache=None):
5474
)
5575

5676
def post(self, *args, **kwargs):
57-
return self.http_client.post(*args, **kwargs)
77+
return NormalizedResponse(self.http_client.post(*args, **kwargs))
5878

5979
def get(self, *args, **kwargs):
60-
return self.http_client.get(*args, **kwargs)
80+
return NormalizedResponse(self.http_client.get(*args, **kwargs))
6181

6282
def close(self):
6383
return self.http_client.close()
@@ -68,12 +88,11 @@ def _hash(raw):
6888

6989

7090
class ThrottledHttpClient(ThrottledHttpClientBase):
71-
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
72-
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)
73-
74-
_post = http_client.post # We'll patch _post, and keep original post() intact
75-
76-
_post = IndividualCache(
91+
"""A throttled http client wrapper that is tailored for MSAL."""
92+
def __init__(self, *args, default_throttle_time=None, **kwargs):
93+
"""Decorate self.post() and self.get() dynamically"""
94+
super(ThrottledHttpClient, self).__init__(*args, **kwargs)
95+
self.post = IndividualCache(
7796
# Internal specs requires throttling on at least token endpoint,
7897
# here we have a generic patch for POST on all endpoints.
7998
mapping=self._expiring_mapping,
@@ -91,9 +110,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91110
_extract_data(kwargs, "username")))), # "account" of ROPC
92111
),
93112
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
94-
)(_post)
113+
)(self.post)
95114

96-
_post = IndividualCache( # It covers the "UI required cache"
115+
self.post = IndividualCache( # It covers the "UI required cache"
97116
mapping=self._expiring_mapping,
98117
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
99118
args[0], # It is the url, typically containing authority and tenant
@@ -128,9 +147,7 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
128147
and "retry-after" not in set( # Leave it to the Retry-After decorator
129148
h.lower() for h in getattr(result, "headers", {}).keys())
130149
else 0,
131-
)(_post)
132-
133-
self.post = _post
150+
)(self.post)
134151

135152
self.get = IndividualCache( # Typically those discovery GETs
136153
mapping=self._expiring_mapping,
@@ -140,9 +157,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140157
),
141158
expires_in=lambda result=None, **ignored:
142159
3600*24 if 200 <= result.status_code < 300 else 0,
143-
)(http_client.get)
144-
145-
# The following 2 methods have been defined dynamically by __init__()
146-
#def post(self, *args, **kwargs): pass
147-
#def get(self, *args, **kwargs): pass
148-
160+
)(self.get)

tests/test_mi.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from mock import patch, ANY, mock_open, Mock
1010
import requests
1111

12-
from tests.http_client import MinimalResponse
12+
from tests.test_throttled_http_client import (
13+
MinimalResponse, ThrottledHttpClientBaseTestCase, DummyHttpClient)
1314
from msal import (
1415
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
1516
ManagedIdentityClient,
1617
ManagedIdentityError,
1718
ArcPlatformNotSupportedError,
1819
)
1920
from msal.managed_identity import (
21+
_ThrottledHttpClient,
2022
_supported_arc_platforms_and_their_prefixes,
2123
get_managed_identity_source,
2224
APP_SERVICE,
@@ -49,6 +51,37 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
4951
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})
5052

5153

54+
class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
55+
def test_throttled_http_client_should_not_alter_original_http_client(self):
56+
self.assertNotAlteringOriginalHttpClient(_ThrottledHttpClient)
57+
58+
def test_throttled_http_client_should_not_cache_successful_http_response(self):
59+
http_cache = {}
60+
http_client=DummyHttpClient(
61+
status_code=200,
62+
response_text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
63+
)
64+
app = ManagedIdentityClient(
65+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
66+
result = app.acquire_token_for_client(resource="R")
67+
self.assertEqual("AT", result["access_token"])
68+
self.assertEqual({}, http_cache, "Should not cache successful http response")
69+
70+
def test_throttled_http_client_should_cache_unsuccessful_http_response(self):
71+
http_cache = {}
72+
http_client=DummyHttpClient(
73+
status_code=400,
74+
response_headers={"Retry-After": "1"},
75+
response_text='{"error": "invalid_request"}',
76+
)
77+
app = ManagedIdentityClient(
78+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
79+
result = app.acquire_token_for_client(resource="R")
80+
self.assertEqual("invalid_request", result["error"])
81+
self.assertNotEqual({}, http_cache, "Should cache unsuccessful http response")
82+
self.assertCleanPickle(http_cache)
83+
84+
5285
class ClientTestCase(unittest.TestCase):
5386
maxDiff = None
5487

tests/test_throttled_http_client.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,43 @@
11
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
2+
import pickle
23
from time import sleep
34
from random import random
45
import logging
5-
from msal.throttled_http_client import ThrottledHttpClient
6+
7+
from msal.throttled_http_client import (
8+
ThrottledHttpClientBase, ThrottledHttpClient, NormalizedResponse)
9+
610
from tests import unittest
7-
from tests.http_client import MinimalResponse
11+
from tests.http_client import MinimalResponse as _MinimalResponse
812

913

1014
logger = logging.getLogger(__name__)
1115
logging.basicConfig(level=logging.DEBUG)
1216

1317

18+
class MinimalResponse(_MinimalResponse):
19+
SIGNATURE = str(random()).encode("utf-8")
20+
21+
def __init__(self, *args, **kwargs):
22+
super().__init__(*args, **kwargs)
23+
self._ = ( # Only an instance attribute will be stored in pickled instance
24+
self.__class__.SIGNATURE) # Useful for testing its presence in pickled instance
25+
26+
1427
class DummyHttpClient(object):
15-
def __init__(self, status_code=None, response_headers=None):
28+
def __init__(self, status_code=None, response_headers=None, response_text=None):
1629
self._status_code = status_code
1730
self._response_headers = response_headers
31+
self._response_text = response_text
1832

1933
def _build_dummy_response(self):
2034
return MinimalResponse(
2135
status_code=self._status_code,
2236
headers=self._response_headers,
23-
text=random(), # So that we'd know whether a new response is received
24-
)
37+
text=self._response_text if self._response_text is not None else str(
38+
random() # So that we'd know whether a new response is received
39+
),
40+
)
2541

2642
def post(self, url, params=None, data=None, headers=None, **kwargs):
2743
return self._build_dummy_response()
@@ -37,19 +53,43 @@ class CloseMethodCalled(Exception):
3753
pass
3854

3955

40-
class TestHttpDecoration(unittest.TestCase):
56+
class ThrottledHttpClientBaseTestCase(unittest.TestCase):
4157

42-
def test_throttled_http_client_should_not_alter_original_http_client(self):
58+
def assertCleanPickle(self, obj):
59+
self.assertTrue(bool(obj), "The object should not be empty")
60+
self.assertNotIn(
61+
MinimalResponse.SIGNATURE, pickle.dumps(obj),
62+
"A pickled object should not contain undesirable data")
63+
64+
def assertValidResponse(self, response):
65+
self.assertIsInstance(response, NormalizedResponse)
66+
self.assertCleanPickle(response)
67+
68+
def test_pickled_minimal_response_should_contain_signature(self):
69+
self.assertIn(MinimalResponse.SIGNATURE, pickle.dumps(MinimalResponse(
70+
status_code=200, headers={}, text="foo")))
71+
72+
def test_throttled_http_client_base_response_should_not_contain_signature(self):
73+
http_client = ThrottledHttpClientBase(DummyHttpClient(status_code=200))
74+
response = http_client.post("https://example.com")
75+
self.assertValidResponse(response)
76+
77+
def assertNotAlteringOriginalHttpClient(self, ThrottledHttpClientClass):
4378
original_http_client = DummyHttpClient()
4479
original_get = original_http_client.get
4580
original_post = original_http_client.post
46-
throttled_http_client = ThrottledHttpClient(original_http_client)
81+
throttled_http_client = ThrottledHttpClientClass(original_http_client)
4782
goal = """The implementation should wrap original http_client
4883
and keep it intact, instead of monkey-patching it"""
4984
self.assertNotEqual(throttled_http_client, original_http_client, goal)
5085
self.assertEqual(original_post, original_http_client.post)
5186
self.assertEqual(original_get, original_http_client.get)
5287

88+
class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
89+
90+
def test_throttled_http_client_should_not_alter_original_http_client(self):
91+
self.assertNotAlteringOriginalHttpClient(ThrottledHttpClient)
92+
5393
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
5494
self, http_client, retry_after):
5595
http_cache = {}
@@ -112,15 +152,23 @@ def test_one_invalid_grant_should_block_a_similar_request(self):
112152
http_client = DummyHttpClient(
113153
status_code=400) # It covers invalid_grant and interaction_required
114154
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
155+
115156
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
116157
logger.debug(http_cache)
158+
self.assertValidResponse(resp1)
117159
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
160+
self.assertValidResponse(resp1_again)
118161
self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
162+
119163
resp2 = http_client.post("https://example.com", data={"claims": "bar"})
164+
self.assertValidResponse(resp2)
120165
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
121166
resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
167+
self.assertValidResponse(resp2_again)
122168
self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
123169

170+
self.assertCleanPickle(http_cache)
171+
124172
def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
125173
"""
126174
Need not test multiple FOCI app's acquire_token_silent() here. By design,

0 commit comments

Comments
 (0)