Skip to content

Commit cb83fa8

Browse files
committed
Replace hardcoded session with http_client param
Remove timeout parameter, requests and httpx behaviors are incompatible anyway
1 parent abd1394 commit cb83fa8

File tree

3 files changed

+92
-19
lines changed

3 files changed

+92
-19
lines changed

oauth2cli/oauth2.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This OAuth2 client implementation aims to be spec-compliant, and generic."""
22
# OAuth2 spec https://tools.ietf.org/html/rfc6749
33

4+
import json
45
try:
56
from urllib.parse import urlencode, parse_qs
67
except ImportError:
@@ -11,6 +12,7 @@
1112
import time
1213
import base64
1314
import sys
15+
import functools
1416

1517
import requests
1618

@@ -35,6 +37,7 @@ def __init__(
3537
self,
3638
server_configuration, # type: dict
3739
client_id, # type: str
40+
http_client=None, # We insert it here to match the upcoming async API
3841
client_secret=None, # type: Optional[str]
3942
client_assertion=None, # type: Union[bytes, callable, None]
4043
client_assertion_type=None, # type: Optional[str]
@@ -57,6 +60,9 @@ def __init__(
5760
or
5861
https://example.com/.../.well-known/openid-configuration
5962
client_id (str): The client's id, issued by the authorization server
63+
http_client (http.HttpClient):
64+
Your implementation of abstract class :class:`http.HttpClient`.
65+
Defaults to a requests session instance.
6066
client_secret (str): Triggers HTTP AUTH for Confidential Client
6167
client_assertion (bytes, callable):
6268
The client assertion to authenticate this client, per RFC 7521.
@@ -76,20 +82,51 @@ def __init__(
7682
you could choose to set this as {"client_secret": "your secret"}
7783
if your authorization server wants it to be in the request body
7884
(rather than in the request header).
85+
86+
verify (boolean):
87+
It will be passed to the
88+
`verify parameter in the underlying requests library
89+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_
90+
This does not apply if you have chosen to pass your own Http client.
91+
proxies (dict):
92+
It will be passed to the
93+
`proxies parameter in the underlying requests library
94+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_
95+
This does not apply if you have chosen to pass your own Http client.
96+
timeout (object):
97+
It will be passed to the
98+
`timeout parameter in the underlying requests library
99+
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_
100+
This does not apply if you have chosen to pass your own Http client.
101+
102+
There is no session-wide `timeout` parameter defined here.
103+
The timeout behavior is determined by the actual http client you use.
104+
If you happen to use Requests, it chose to not support session-wide timeout
105+
(https://github.com/psf/requests/issues/3341), but you can patch that by:
106+
107+
s = requests.Session()
108+
s.request = functools.partial(s.request, timeout=3)
109+
110+
and then feed that patched session instance to this class.
79111
"""
80112
self.configuration = server_configuration
81113
self.client_id = client_id
82114
self.client_secret = client_secret
83115
self.client_assertion = client_assertion
116+
self.default_headers = default_headers or {}
84117
self.default_body = default_body or {}
85118
if client_assertion_type is not None:
86119
self.default_body["client_assertion_type"] = client_assertion_type
87120
self.logger = logging.getLogger(__name__)
88-
self.session = s = requests.Session()
89-
s.headers.update(default_headers or {})
90-
s.verify = verify
91-
s.proxies = proxies or {}
92-
self.timeout = timeout
121+
if http_client:
122+
self.http_client = http_client
123+
else:
124+
self.http_client = requests.Session()
125+
self.http_client.verify = verify
126+
self.http_client.proxies = proxies
127+
self.http_client.request = functools.partial(
128+
# A workaround for requests not supporting session-wide timeout
129+
self.http_client.request, timeout=timeout)
93130

94131
def _build_auth_request_params(self, response_type, **kwargs):
95132
# response_type is a string defined in
@@ -110,7 +147,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
110147
params=None, # a dict to be sent as query string to the endpoint
111148
data=None, # All relevant data, which will go into the http body
112149
headers=None, # a dict to be sent as request headers
113-
timeout=None,
114150
post=None, # A callable to replace requests.post(), for testing.
115151
# Such as: lambda url, **kwargs:
116152
# Mock(status_code=200, json=Mock(return_value={}))
@@ -128,38 +164,40 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
128164

129165
_data.update(self.default_body) # It may contain authen parameters
130166
_data.update(data or {}) # So the content in data param prevails
131-
# We don't have to clean up None values here, because requests lib will.
167+
_data = {k: v for k, v in _data.items() if v} # Clean up None values
132168

133169
if _data.get('scope'):
134170
_data['scope'] = self._stringify(_data['scope'])
135171

172+
_headers = {'Accept': 'application/json'}
173+
_headers.update(self.default_headers)
174+
_headers.update(headers or {})
175+
136176
# Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1
137177
# Clients in possession of a client password MAY use the HTTP Basic
138178
# authentication.
139179
# Alternatively, (but NOT RECOMMENDED,)
140180
# the authorization server MAY support including the
141181
# client credentials in the request-body using the following
142182
# parameters: client_id, client_secret.
143-
auth = None
144183
if self.client_secret and self.client_id:
145-
auth = (self.client_id, self.client_secret) # for HTTP Basic Auth
184+
_headers["Authorization"] = "Basic " + base64.b64encode(
185+
"{}:{}".format(self.client_id, self.client_secret)
186+
.encode("ascii")).decode("ascii")
146187

147188
if "token_endpoint" not in self.configuration:
148189
raise ValueError("token_endpoint not found in configuration")
149-
_headers = {'Accept': 'application/json'}
150-
_headers.update(headers or {})
151-
resp = (post or self.session.post)(
190+
resp = (post or self.http_client.post)(
152191
self.configuration["token_endpoint"],
153-
headers=_headers, params=params, data=_data, auth=auth,
154-
timeout=timeout or self.timeout,
192+
headers=_headers, params=params, data=_data,
155193
**kwargs)
156194
if resp.status_code >= 500:
157195
resp.raise_for_status() # TODO: Will probably retry here
158196
try:
159197
# The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says
160198
# even an error response will be a valid json structure,
161199
# so we simply return it here, without needing to invent an exception.
162-
return resp.json()
200+
return json.loads(resp.text)
163201
except ValueError:
164202
self.logger.exception(
165203
"Token response is not in json format: %s", resp.text)
@@ -200,7 +238,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
200238
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}
201239

202240

203-
def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
241+
def initiate_device_flow(self, scope=None, **kwargs):
204242
# type: (list, **dict) -> dict
205243
# The naming of this method is following the wording of this specs
206244
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
@@ -218,10 +256,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
218256
DAE = "device_authorization_endpoint"
219257
if not self.configuration.get(DAE):
220258
raise ValueError("You need to provide device authorization endpoint")
221-
flow = self.session.post(self.configuration[DAE],
259+
resp = self.http_client.post(self.configuration[DAE],
222260
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
223-
timeout=timeout or self.timeout,
224-
**kwargs).json()
261+
headers=dict(self.default_headers, **kwargs.pop("headers", {})),
262+
**kwargs)
263+
flow = json.loads(resp.text)
225264
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
226265
flow["expires_in"] = int(flow.get("expires_in", 1800))
227266
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this

tests/http_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import requests
2+
3+
4+
class MinimalHttpClient:
5+
6+
def __init__(self, verify=True, proxies=None, timeout=None):
7+
self.session = requests.Session()
8+
self.session.verify = verify
9+
self.session.proxies = proxies
10+
self.timeout = timeout
11+
12+
def post(self, url, params=None, data=None, headers=None, **kwargs):
13+
return MinimalResponse(requests_resp=self.session.post(
14+
url, params=params, data=data, headers=headers,
15+
timeout=self.timeout))
16+
17+
def get(self, url, params=None, headers=None, **kwargs):
18+
return MinimalResponse(requests_resp=self.session.get(
19+
url, params=params, headers=headers, timeout=self.timeout))
20+
21+
22+
class MinimalResponse(object): # Not for production use
23+
def __init__(self, requests_resp=None, status_code=None, text=None):
24+
self.status_code = status_code or requests_resp.status_code
25+
self.text = text or requests_resp.text
26+
self._raw_resp = requests_resp
27+
28+
def raise_for_status(self):
29+
if self._raw_resp:
30+
self._raw_resp.raise_for_status()

tests/test_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from oauth2cli.authcode import obtain_auth_code
1414
from oauth2cli.assertion import JwtSigner
1515
from tests import unittest, Oauth2TestCase
16+
from tests.http_client import MinimalHttpClient
1617

1718

1819
logging.basicConfig(level=logging.DEBUG)
@@ -83,13 +84,15 @@ class TestClient(Oauth2TestCase):
8384

8485
@classmethod
8586
def setUpClass(cls):
87+
http_client = MinimalHttpClient()
8688
if "client_certificate" in CONFIG:
8789
private_key_path = CONFIG["client_certificate"]["private_key_path"]
8890
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
8991
private_key = f.read() # Expecting PEM format
9092
cls.client = Client(
9193
CONFIG["openid_configuration"],
9294
CONFIG['client_id'],
95+
http_client=http_client,
9396
client_assertion=JwtSigner(
9497
private_key,
9598
algorithm="RS256",
@@ -103,6 +106,7 @@ def setUpClass(cls):
103106
else:
104107
cls.client = Client(
105108
CONFIG["openid_configuration"], CONFIG['client_id'],
109+
http_client=http_client,
106110
client_secret=CONFIG.get('client_secret'))
107111

108112
@unittest.skipIf(

0 commit comments

Comments
 (0)