Skip to content

Commit 43c202c

Browse files
committed
Only cache desirable data in http cache
1 parent b92b4f1 commit 43c202c

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

msal/throttled_http_client.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from .individual_cache import _IndividualCache as IndividualCache
55
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
from .oauth2cli.http import Response
7+
from .exceptions import MsalServiceError
68

79

810
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
@@ -67,13 +69,30 @@ def _hash(raw):
6769
return sha256(repr(raw).encode("utf-8")).hexdigest()
6870

6971

72+
class NormalizedResponse(Response):
73+
"""A http response with the shape defined in Response,
74+
but contains only the data we will store in cache.
75+
"""
76+
def __init__(self, raw_response):
77+
super().__init__()
78+
self.status_code = raw_response.status_code
79+
self.text = raw_response.text
80+
self.headers = raw_response.headers
81+
82+
## Note: Don't use the following line,
83+
## because when being pickled, it will indirectly pickle the whole raw_response
84+
# self.raise_for_status = raw_response.raise_for_status
85+
def raise_for_status(self):
86+
if self.status_code >= 400:
87+
raise MsalServiceError("HTTP Error: {}".format(self.status_code))
88+
89+
7090
class ThrottledHttpClient(ThrottledHttpClientBase):
91+
"""A throttled http client wrapper that is tailored for MSAL."""
7192
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
93+
"""Decorate self.post() and self.get() dynamically"""
7294
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)
73-
74-
_post = http_client.post # We'll patch _post, and keep original post() intact
75-
76-
_post = IndividualCache(
95+
self.post = IndividualCache(
7796
# Internal specs requires throttling on at least token endpoint,
7897
# here we have a generic patch for POST on all endpoints.
7998
mapping=self._expiring_mapping,
@@ -91,9 +110,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91110
_extract_data(kwargs, "username")))), # "account" of ROPC
92111
),
93112
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
94-
)(_post)
113+
)(self.post)
95114

96-
_post = IndividualCache( # It covers the "UI required cache"
115+
self.post = IndividualCache( # It covers the "UI required cache"
97116
mapping=self._expiring_mapping,
98117
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
99118
args[0], # It is the url, typically containing authority and tenant
@@ -128,9 +147,7 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
128147
and "retry-after" not in set( # Leave it to the Retry-After decorator
129148
h.lower() for h in getattr(result, "headers", {}).keys())
130149
else 0,
131-
)(_post)
132-
133-
self.post = _post
150+
)(self.post)
134151

135152
self.get = IndividualCache( # Typically those discovery GETs
136153
mapping=self._expiring_mapping,
@@ -140,9 +157,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140157
),
141158
expires_in=lambda result=None, **ignored:
142159
3600*24 if 200 <= result.status_code < 300 else 0,
143-
)(http_client.get)
160+
)(self.get)
144161

145-
# The following 2 methods have been defined dynamically by __init__()
146-
#def post(self, *args, **kwargs): pass
147-
#def get(self, *args, **kwargs): pass
162+
def post(self, *args, **kwargs):
163+
return NormalizedResponse(super(ThrottledHttpClient, self).post(*args, **kwargs))
148164

165+
def get(self, *args, **kwargs):
166+
return NormalizedResponse(super(ThrottledHttpClient, self).get(*args, **kwargs))

tests/test_throttled_http_client.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
2+
import pickle
23
from time import sleep
34
from random import random
45
import logging
5-
from msal.throttled_http_client import ThrottledHttpClient
6+
7+
from msal.throttled_http_client import ThrottledHttpClient, NormalizedResponse
8+
69
from tests import unittest
710
from tests.http_client import MinimalResponse
811

@@ -12,16 +15,19 @@
1215

1316

1417
class DummyHttpClient(object):
18+
SIGNATURE = str(random()).encode("utf-8")
1519
def __init__(self, status_code=None, response_headers=None):
1620
self._status_code = status_code
1721
self._response_headers = response_headers
1822

1923
def _build_dummy_response(self):
20-
return MinimalResponse(
24+
response = MinimalResponse(
2125
status_code=self._status_code,
2226
headers=self._response_headers,
2327
text=random(), # So that we'd know whether a new response is received
2428
)
29+
response.undesirable_data = self.__class__.SIGNATURE # To be tested for its absence in pickled response
30+
return response
2531

2632
def post(self, url, params=None, data=None, headers=None, **kwargs):
2733
return self._build_dummy_response()
@@ -39,6 +45,12 @@ class CloseMethodCalled(Exception):
3945

4046
class TestHttpDecoration(unittest.TestCase):
4147

48+
def assertValidResponse(self, response):
49+
self.assertIsInstance(response, NormalizedResponse)
50+
self.assertNotIn(
51+
DummyHttpClient.SIGNATURE, pickle.dumps(response),
52+
"A pickled response should not contain undesirable data")
53+
4254
def test_throttled_http_client_should_not_alter_original_http_client(self):
4355
original_http_client = DummyHttpClient()
4456
original_get = original_http_client.get
@@ -112,13 +124,19 @@ def test_one_invalid_grant_should_block_a_similar_request(self):
112124
http_client = DummyHttpClient(
113125
status_code=400) # It covers invalid_grant and interaction_required
114126
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
127+
115128
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
116129
logger.debug(http_cache)
130+
self.assertValidResponse(resp1)
117131
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
132+
self.assertValidResponse(resp1_again)
118133
self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
134+
119135
resp2 = http_client.post("https://example.com", data={"claims": "bar"})
136+
self.assertValidResponse(resp2)
120137
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
121138
resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
139+
self.assertValidResponse(resp2_again)
122140
self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
123141

124142
def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):

0 commit comments

Comments
 (0)