Skip to content

Commit 586a8ad

Browse files
committed
Merge remote-tracking branch 'oauth2cli/dev' into bar
2 parents 685b14b + 6d8647e commit 586a8ad

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

msal/oauth2cli/oauth2.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def encode_saml_assertion(assertion):
3333
CLIENT_ASSERTION_TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
3434
client_assertion_encoders = {CLIENT_ASSERTION_TYPE_SAML2: encode_saml_assertion}
3535

36+
@property
37+
def session(self):
38+
warnings.warn("Will be gone in next major release", DeprecationWarning)
39+
return self._http_client
40+
41+
@session.setter
42+
def session(self, value):
43+
warnings.warn("Will be gone in next major release", DeprecationWarning)
44+
self._http_client = value
45+
46+
3647
def __init__(
3748
self,
3849
server_configuration, # type: dict
@@ -43,7 +54,7 @@ def __init__(
4354
client_assertion_type=None, # type: Optional[str]
4455
default_headers=None, # type: Optional[dict]
4556
default_body=None, # type: Optional[dict]
46-
verify=True, # type: Union[str, True, False, None]
57+
verify=None, # type: Union[str, True, False, None]
4758
proxies=None, # type: Optional[dict]
4859
timeout=None, # type: Union[tuple, float, None]
4960
):
@@ -60,9 +71,21 @@ def __init__(
6071
or
6172
https://example.com/.../.well-known/openid-configuration
6273
client_id (str): The client's id, issued by the authorization server
74+
6375
http_client (http.HttpClient):
6476
Your implementation of abstract class :class:`http.HttpClient`.
6577
Defaults to a requests session instance.
78+
79+
There is no session-wide `timeout` parameter defined here.
80+
Timeout behavior is determined by the actual http client you use.
81+
If you happen to use Requests, it disallows session-wide timeout
82+
(https://github.com/psf/requests/issues/3341). The workaround is:
83+
84+
s = requests.Session()
85+
s.request = functools.partial(s.request, timeout=3)
86+
87+
and then feed that patched session instance to this class.
88+
6689
client_secret (str): Triggers HTTP AUTH for Confidential Client
6790
client_assertion (bytes, callable):
6891
The client assertion to authenticate this client, per RFC 7521.
@@ -86,28 +109,25 @@ def __init__(
86109
verify (boolean):
87110
It will be passed to the
88111
`verify parameter in the underlying requests library
89-
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_
112+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_.
113+
When leaving it with default value (None), we will use True instead.
114+
90115
This does not apply if you have chosen to pass your own Http client.
116+
91117
proxies (dict):
92118
It will be passed to the
93119
`proxies parameter in the underlying requests library
94-
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_
120+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_.
121+
95122
This does not apply if you have chosen to pass your own Http client.
123+
96124
timeout (object):
97125
It will be passed to the
98126
`timeout parameter in the underlying requests library
99-
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_
100-
This does not apply if you have chosen to pass your own Http client.
101-
102-
There is no session-wide `timeout` parameter defined here.
103-
The timeout behavior is determined by the actual http client you use.
104-
If you happen to use Requests, it chose to not support session-wide timeout
105-
(https://github.com/psf/requests/issues/3341), but you can patch that by:
127+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_.
106128
107-
s = requests.Session()
108-
s.request = functools.partial(s.request, timeout=3)
129+
This does not apply if you have chosen to pass your own Http client.
109130
110-
and then feed that patched session instance to this class.
111131
"""
112132
self.configuration = server_configuration
113133
self.client_id = client_id
@@ -119,14 +139,18 @@ def __init__(
119139
self.default_body["client_assertion_type"] = client_assertion_type
120140
self.logger = logging.getLogger(__name__)
121141
if http_client:
122-
self.http_client = http_client
142+
if verify is not None or proxies is not None or timeout is not None:
143+
raise ValueError(
144+
"verify, proxies, or timeout is not allowed "
145+
"when http_client is in use")
146+
self._http_client = http_client
123147
else:
124-
self.http_client = requests.Session()
125-
self.http_client.verify = verify
126-
self.http_client.proxies = proxies
127-
self.http_client.request = functools.partial(
148+
self._http_client = requests.Session()
149+
self._http_client.verify = True if verify is None else verify
150+
self._http_client.proxies = proxies
151+
self._http_client.request = functools.partial(
128152
# A workaround for requests not supporting session-wide timeout
129-
self.http_client.request, timeout=timeout)
153+
self._http_client.request, timeout=timeout)
130154

131155
def _build_auth_request_params(self, response_type, **kwargs):
132156
# response_type is a string defined in
@@ -187,7 +211,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
187211

188212
if "token_endpoint" not in self.configuration:
189213
raise ValueError("token_endpoint not found in configuration")
190-
resp = (post or self.http_client.post)(
214+
resp = (post or self._http_client.post)(
191215
self.configuration["token_endpoint"],
192216
headers=_headers, params=params, data=_data,
193217
**kwargs)
@@ -256,7 +280,7 @@ def initiate_device_flow(self, scope=None, **kwargs):
256280
DAE = "device_authorization_endpoint"
257281
if not self.configuration.get(DAE):
258282
raise ValueError("You need to provide device authorization endpoint")
259-
resp = self.http_client.post(self.configuration[DAE],
283+
resp = self._http_client.post(self.configuration[DAE],
260284
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
261285
headers=dict(self.default_headers, **kwargs.pop("headers", {})),
262286
**kwargs)

tests/test_client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,11 @@ def test_device_flow(self):
174174
assertion=lambda: self.assertIn('access_token', result),
175175
skippable_errors=self.client.DEVICE_FLOW_RETRIABLE_ERRORS)
176176

177+
178+
class TestSessionAccessibility(unittest.TestCase):
179+
def test_accessing_session_property_for_backward_compatibility(self):
180+
client = Client({}, "client_id")
181+
client.session
182+
client.session.close()
183+
client.session = "something"
184+

0 commit comments

Comments
 (0)