Skip to content

Commit cab41f3

Browse files
committed
Merge branch 'http_session_injected' into dev
2 parents abd1394 + f856858 commit cab41f3

File tree

3 files changed

+106
-20
lines changed

3 files changed

+106
-20
lines changed

oauth2cli/oauth2.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This OAuth2 client implementation aims to be spec-compliant, and generic."""
22
# OAuth2 spec https://tools.ietf.org/html/rfc6749
33

4+
import json
45
try:
56
from urllib.parse import urlencode, parse_qs
67
except ImportError:
@@ -11,6 +12,7 @@
1112
import time
1213
import base64
1314
import sys
15+
import functools
1416

1517
import requests
1618

@@ -35,12 +37,13 @@ def __init__(
3537
self,
3638
server_configuration, # type: dict
3739
client_id, # type: str
40+
http_client=None, # We insert it here to match the upcoming async API
3841
client_secret=None, # type: Optional[str]
3942
client_assertion=None, # type: Union[bytes, callable, None]
4043
client_assertion_type=None, # type: Optional[str]
4144
default_headers=None, # type: Optional[dict]
4245
default_body=None, # type: Optional[dict]
43-
verify=True, # type: Union[str, True, False, None]
46+
verify=None, # type: Union[str, True, False, None]
4447
proxies=None, # type: Optional[dict]
4548
timeout=None, # type: Union[tuple, float, None]
4649
):
@@ -57,6 +60,21 @@ def __init__(
5760
or
5861
https://example.com/.../.well-known/openid-configuration
5962
client_id (str): The client's id, issued by the authorization server
63+
64+
http_client (http.HttpClient):
65+
Your implementation of abstract class :class:`http.HttpClient`.
66+
Defaults to a requests session instance.
67+
68+
There is no session-wide `timeout` parameter defined here.
69+
Timeout behavior is determined by the actual http client you use.
70+
If you happen to use Requests, it disallows session-wide timeout
71+
(https://github.com/psf/requests/issues/3341). The workaround is:
72+
73+
s = requests.Session()
74+
s.request = functools.partial(s.request, timeout=3)
75+
76+
and then feed that patched session instance to this class.
77+
6078
client_secret (str): Triggers HTTP AUTH for Confidential Client
6179
client_assertion (bytes, callable):
6280
The client assertion to authenticate this client, per RFC 7521.
@@ -76,20 +94,52 @@ def __init__(
7694
you could choose to set this as {"client_secret": "your secret"}
7795
if your authorization server wants it to be in the request body
7896
(rather than in the request header).
97+
98+
verify (boolean):
99+
It will be passed to the
100+
`verify parameter in the underlying requests library
101+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_.
102+
When leaving it with default value (None), we will use True instead.
103+
104+
This does not apply if you have chosen to pass your own Http client.
105+
106+
proxies (dict):
107+
It will be passed to the
108+
`proxies parameter in the underlying requests library
109+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_.
110+
111+
This does not apply if you have chosen to pass your own Http client.
112+
113+
timeout (object):
114+
It will be passed to the
115+
`timeout parameter in the underlying requests library
116+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_.
117+
118+
This does not apply if you have chosen to pass your own Http client.
119+
79120
"""
80121
self.configuration = server_configuration
81122
self.client_id = client_id
82123
self.client_secret = client_secret
83124
self.client_assertion = client_assertion
125+
self.default_headers = default_headers or {}
84126
self.default_body = default_body or {}
85127
if client_assertion_type is not None:
86128
self.default_body["client_assertion_type"] = client_assertion_type
87129
self.logger = logging.getLogger(__name__)
88-
self.session = s = requests.Session()
89-
s.headers.update(default_headers or {})
90-
s.verify = verify
91-
s.proxies = proxies or {}
92-
self.timeout = timeout
130+
if http_client:
131+
if verify is not None or proxies is not None or timeout is not None:
132+
raise ValueError(
133+
"verify, proxies, or timeout is not allowed "
134+
"when http_client is in use")
135+
self.http_client = http_client
136+
else:
137+
self.http_client = requests.Session()
138+
self.http_client.verify = True if verify is None else verify
139+
self.http_client.proxies = proxies
140+
self.http_client.request = functools.partial(
141+
# A workaround for requests not supporting session-wide timeout
142+
self.http_client.request, timeout=timeout)
93143

94144
def _build_auth_request_params(self, response_type, **kwargs):
95145
# response_type is a string defined in
@@ -110,7 +160,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
110160
params=None, # a dict to be sent as query string to the endpoint
111161
data=None, # All relevant data, which will go into the http body
112162
headers=None, # a dict to be sent as request headers
113-
timeout=None,
114163
post=None, # A callable to replace requests.post(), for testing.
115164
# Such as: lambda url, **kwargs:
116165
# Mock(status_code=200, json=Mock(return_value={}))
@@ -128,38 +177,40 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
128177

129178
_data.update(self.default_body) # It may contain authen parameters
130179
_data.update(data or {}) # So the content in data param prevails
131-
# We don't have to clean up None values here, because requests lib will.
180+
_data = {k: v for k, v in _data.items() if v} # Clean up None values
132181

133182
if _data.get('scope'):
134183
_data['scope'] = self._stringify(_data['scope'])
135184

185+
_headers = {'Accept': 'application/json'}
186+
_headers.update(self.default_headers)
187+
_headers.update(headers or {})
188+
136189
# Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1
137190
# Clients in possession of a client password MAY use the HTTP Basic
138191
# authentication.
139192
# Alternatively, (but NOT RECOMMENDED,)
140193
# the authorization server MAY support including the
141194
# client credentials in the request-body using the following
142195
# parameters: client_id, client_secret.
143-
auth = None
144196
if self.client_secret and self.client_id:
145-
auth = (self.client_id, self.client_secret) # for HTTP Basic Auth
197+
_headers["Authorization"] = "Basic " + base64.b64encode(
198+
"{}:{}".format(self.client_id, self.client_secret)
199+
.encode("ascii")).decode("ascii")
146200

147201
if "token_endpoint" not in self.configuration:
148202
raise ValueError("token_endpoint not found in configuration")
149-
_headers = {'Accept': 'application/json'}
150-
_headers.update(headers or {})
151-
resp = (post or self.session.post)(
203+
resp = (post or self.http_client.post)(
152204
self.configuration["token_endpoint"],
153-
headers=_headers, params=params, data=_data, auth=auth,
154-
timeout=timeout or self.timeout,
205+
headers=_headers, params=params, data=_data,
155206
**kwargs)
156207
if resp.status_code >= 500:
157208
resp.raise_for_status() # TODO: Will probably retry here
158209
try:
159210
# The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says
160211
# even an error response will be a valid json structure,
161212
# so we simply return it here, without needing to invent an exception.
162-
return resp.json()
213+
return json.loads(resp.text)
163214
except ValueError:
164215
self.logger.exception(
165216
"Token response is not in json format: %s", resp.text)
@@ -200,7 +251,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
200251
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}
201252

202253

203-
def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
254+
def initiate_device_flow(self, scope=None, **kwargs):
204255
# type: (list, **dict) -> dict
205256
# The naming of this method is following the wording of this specs
206257
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
@@ -218,10 +269,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
218269
DAE = "device_authorization_endpoint"
219270
if not self.configuration.get(DAE):
220271
raise ValueError("You need to provide device authorization endpoint")
221-
flow = self.session.post(self.configuration[DAE],
272+
resp = self.http_client.post(self.configuration[DAE],
222273
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
223-
timeout=timeout or self.timeout,
224-
**kwargs).json()
274+
headers=dict(self.default_headers, **kwargs.pop("headers", {})),
275+
**kwargs)
276+
flow = json.loads(resp.text)
225277
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
226278
flow["expires_in"] = int(flow.get("expires_in", 1800))
227279
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this

tests/http_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import requests
2+
3+
4+
class MinimalHttpClient:
5+
6+
def __init__(self, verify=True, proxies=None, timeout=None):
7+
self.session = requests.Session()
8+
self.session.verify = verify
9+
self.session.proxies = proxies
10+
self.timeout = timeout
11+
12+
def post(self, url, params=None, data=None, headers=None, **kwargs):
13+
return MinimalResponse(requests_resp=self.session.post(
14+
url, params=params, data=data, headers=headers,
15+
timeout=self.timeout))
16+
17+
def get(self, url, params=None, headers=None, **kwargs):
18+
return MinimalResponse(requests_resp=self.session.get(
19+
url, params=params, headers=headers, timeout=self.timeout))
20+
21+
22+
class MinimalResponse(object): # Not for production use
23+
def __init__(self, requests_resp=None, status_code=None, text=None):
24+
self.status_code = status_code or requests_resp.status_code
25+
self.text = text or requests_resp.text
26+
self._raw_resp = requests_resp
27+
28+
def raise_for_status(self):
29+
if self._raw_resp:
30+
self._raw_resp.raise_for_status()

tests/test_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from oauth2cli.authcode import obtain_auth_code
1414
from oauth2cli.assertion import JwtSigner
1515
from tests import unittest, Oauth2TestCase
16+
from tests.http_client import MinimalHttpClient
1617

1718

1819
logging.basicConfig(level=logging.DEBUG)
@@ -83,13 +84,15 @@ class TestClient(Oauth2TestCase):
8384

8485
@classmethod
8586
def setUpClass(cls):
87+
http_client = MinimalHttpClient()
8688
if "client_certificate" in CONFIG:
8789
private_key_path = CONFIG["client_certificate"]["private_key_path"]
8890
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
8991
private_key = f.read() # Expecting PEM format
9092
cls.client = Client(
9193
CONFIG["openid_configuration"],
9294
CONFIG['client_id'],
95+
http_client=http_client,
9396
client_assertion=JwtSigner(
9497
private_key,
9598
algorithm="RS256",
@@ -103,6 +106,7 @@ def setUpClass(cls):
103106
else:
104107
cls.client = Client(
105108
CONFIG["openid_configuration"], CONFIG['client_id'],
109+
http_client=http_client,
106110
client_secret=CONFIG.get('client_secret'))
107111

108112
@unittest.skipIf(

0 commit comments

Comments
 (0)