Skip to content

Commit 9079588

Browse files
committed
Implement verify, proxies, timeout based on requests.session
1 parent 595dcde commit 9079588

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

oauth2cli/oauth2.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(
2727
client_assertion_type=None, # type: Optional[str]
2828
default_headers=None, # type: Optional[dict]
2929
default_body=None, # type: Optional[dict]
30+
verify=True, # type: Union[str, True, False, None]
31+
proxies=None, # type: Optional[dict]
32+
timeout=None, # type: Union[tuple, float, None]
3033
):
3134
"""Initialize a client object to talk all the OAuth2 grants to the server.
3235
@@ -62,16 +65,20 @@ def __init__(
6265
self.configuration = server_configuration
6366
self.client_id = client_id
6467
self.client_secret = client_secret
65-
self.default_headers = default_headers or {}
6668
self.default_body = default_body or {}
6769
if client_assertion is not None: # See https://tools.ietf.org/html/rfc7521#section-4.2
68-
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
69-
TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
7070
if client_assertion_type is None: # RFC7521 defines only 2 profiles
71+
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
72+
TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
7173
client_assertion_type = TYPE_JWT if "." in client_assertion else TYPE_SAML2
7274
self.default_body["client_assertion"] = client_assertion
7375
self.default_body["client_assertion_type"] = client_assertion_type
7476
self.logger = logging.getLogger(__name__)
77+
self.session = s = requests.Session()
78+
s.headers.update(default_headers or {})
79+
s.verify = verify
80+
s.proxies = proxies or {}
81+
self.timeout = timeout
7582

7683
def _build_auth_request_params(self, response_type, **kwargs):
7784
# response_type is a string defined in
@@ -92,6 +99,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
9299
params=None, # a dict to be sent as query string to the endpoint
93100
data=None, # All relevant data, which will go into the http body
94101
headers=None, # a dict to be sent as request headers
102+
timeout=None,
95103
**kwargs # Relay all extra parameters to underlying requests
96104
): # Returns the json object came from the OAUTH2 response
97105
_data = {'client_id': self.client_id, 'grant_type': grant_type}
@@ -116,11 +124,12 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
116124
if "token_endpoint" not in self.configuration:
117125
raise ValueError("token_endpoint not found in configuration")
118126
_headers = {'Accept': 'application/json'}
119-
_headers.update(self.default_headers)
120127
_headers.update(headers or {})
121-
resp = requests.post(
128+
resp = self.session.post(
122129
self.configuration["token_endpoint"],
123-
headers=_headers, params=params, data=_data, auth=auth, **kwargs)
130+
headers=_headers, params=params, data=_data, auth=auth,
131+
timeout=timeout or self.timeout,
132+
**kwargs)
124133
if resp.status_code >= 500:
125134
resp.raise_for_status() # TODO: Will probably retry here
126135
try:
@@ -164,7 +173,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
164173
GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
165174
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
166175

167-
def initiate_device_flow(self, scope=None, **kwargs):
176+
def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
168177
# type: (list, **dict) -> dict
169178
# The naming of this method is following the wording of this specs
170179
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
@@ -182,8 +191,9 @@ def initiate_device_flow(self, scope=None, **kwargs):
182191
DAE = "device_authorization_endpoint"
183192
if not self.configuration.get(DAE):
184193
raise ValueError("You need to provide device authorization endpoint")
185-
flow = requests.post(self.configuration[DAE], headers=self.default_headers,
194+
flow = self.session.post(self.configuration[DAE],
186195
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
196+
timeout=timeout or self.timeout,
187197
**kwargs).json()
188198
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
189199
flow["expires_in"] = int(flow.get("expires_in", 1800))

0 commit comments

Comments
 (0)