11"""This OAuth2 client implementation aims to be spec-compliant, and generic."""
22# OAuth2 spec https://tools.ietf.org/html/rfc6749
33
4+ import json
45try :
56 from urllib .parse import urlencode , parse_qs
67except ImportError :
1112import time
1213import base64
1314import sys
15+ import functools
1416
1517import requests
1618
@@ -35,12 +37,13 @@ def __init__(
3537 self ,
3638 server_configuration , # type: dict
3739 client_id , # type: str
40+ http_client = None , # We insert it here to match the upcoming async API
3841 client_secret = None , # type: Optional[str]
3942 client_assertion = None , # type: Union[bytes, callable, None]
4043 client_assertion_type = None , # type: Optional[str]
4144 default_headers = None , # type: Optional[dict]
4245 default_body = None , # type: Optional[dict]
43- verify = True , # type: Union[str, True, False, None]
46+ verify = None , # type: Union[str, True, False, None]
4447 proxies = None , # type: Optional[dict]
4548 timeout = None , # type: Union[tuple, float, None]
4649 ):
@@ -57,6 +60,21 @@ def __init__(
5760 or
5861 https://example.com/.../.well-known/openid-configuration
5962 client_id (str): The client's id, issued by the authorization server
63+
64+ http_client (http.HttpClient):
65+ Your implementation of abstract class :class:`http.HttpClient`.
66+ Defaults to a requests session instance.
67+
68+ There is no session-wide `timeout` parameter defined here.
69+ Timeout behavior is determined by the actual http client you use.
70+ If you happen to use Requests, it disallows session-wide timeout
71+ (https://github.com/psf/requests/issues/3341). The workaround is:
72+
73+ s = requests.Session()
74+ s.request = functools.partial(s.request, timeout=3)
75+
76+ and then feed that patched session instance to this class.
77+
6078 client_secret (str): Triggers HTTP AUTH for Confidential Client
6179 client_assertion (bytes, callable):
6280 The client assertion to authenticate this client, per RFC 7521.
@@ -76,20 +94,52 @@ def __init__(
7694 you could choose to set this as {"client_secret": "your secret"}
7795 if your authorization server wants it to be in the request body
7896 (rather than in the request header).
97+
98+ verify (boolean):
99+ It will be passed to the
100+ `verify parameter in the underlying requests library
101+ <http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_.
102+ When leaving it with default value (None), we will use True instead.
103+
104+ This does not apply if you have chosen to pass your own Http client.
105+
106+ proxies (dict):
107+ It will be passed to the
108+ `proxies parameter in the underlying requests library
109+ <http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_.
110+
111+ This does not apply if you have chosen to pass your own Http client.
112+
113+ timeout (object):
114+ It will be passed to the
115+ `timeout parameter in the underlying requests library
116+ <http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_.
117+
118+ This does not apply if you have chosen to pass your own Http client.
119+
79120 """
80121 self .configuration = server_configuration
81122 self .client_id = client_id
82123 self .client_secret = client_secret
83124 self .client_assertion = client_assertion
125+ self .default_headers = default_headers or {}
84126 self .default_body = default_body or {}
85127 if client_assertion_type is not None :
86128 self .default_body ["client_assertion_type" ] = client_assertion_type
87129 self .logger = logging .getLogger (__name__ )
88- self .session = s = requests .Session ()
89- s .headers .update (default_headers or {})
90- s .verify = verify
91- s .proxies = proxies or {}
92- self .timeout = timeout
130+ if http_client :
131+ if verify is not None or proxies is not None or timeout is not None :
132+ raise ValueError (
133+ "verify, proxies, or timeout is not allowed "
134+ "when http_client is in use" )
135+ self .http_client = http_client
136+ else :
137+ self .http_client = requests .Session ()
138+ self .http_client .verify = True if verify is None else verify
139+ self .http_client .proxies = proxies
140+ self .http_client .request = functools .partial (
141+ # A workaround for requests not supporting session-wide timeout
142+ self .http_client .request , timeout = timeout )
93143
94144 def _build_auth_request_params (self , response_type , ** kwargs ):
95145 # response_type is a string defined in
@@ -110,7 +160,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
110160 params = None , # a dict to be sent as query string to the endpoint
111161 data = None , # All relevant data, which will go into the http body
112162 headers = None , # a dict to be sent as request headers
113- timeout = None ,
114163 post = None , # A callable to replace requests.post(), for testing.
115164 # Such as: lambda url, **kwargs:
116165 # Mock(status_code=200, json=Mock(return_value={}))
@@ -128,38 +177,40 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
128177
129178 _data .update (self .default_body ) # It may contain authen parameters
130179 _data .update (data or {}) # So the content in data param prevails
131- # We don't have to clean up None values here, because requests lib will.
180+ _data = { k : v for k , v in _data . items () if v } # Clean up None values
132181
133182 if _data .get ('scope' ):
134183 _data ['scope' ] = self ._stringify (_data ['scope' ])
135184
185+ _headers = {'Accept' : 'application/json' }
186+ _headers .update (self .default_headers )
187+ _headers .update (headers or {})
188+
136189 # Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1
137190 # Clients in possession of a client password MAY use the HTTP Basic
138191 # authentication.
139192 # Alternatively, (but NOT RECOMMENDED,)
140193 # the authorization server MAY support including the
141194 # client credentials in the request-body using the following
142195 # parameters: client_id, client_secret.
143- auth = None
144196 if self .client_secret and self .client_id :
145- auth = (self .client_id , self .client_secret ) # for HTTP Basic Auth
197+ _headers ["Authorization" ] = "Basic " + base64 .b64encode (
198+ "{}:{}" .format (self .client_id , self .client_secret )
199+ .encode ("ascii" )).decode ("ascii" )
146200
147201 if "token_endpoint" not in self .configuration :
148202 raise ValueError ("token_endpoint not found in configuration" )
149- _headers = {'Accept' : 'application/json' }
150- _headers .update (headers or {})
151- resp = (post or self .session .post )(
203+ resp = (post or self .http_client .post )(
152204 self .configuration ["token_endpoint" ],
153- headers = _headers , params = params , data = _data , auth = auth ,
154- timeout = timeout or self .timeout ,
205+ headers = _headers , params = params , data = _data ,
155206 ** kwargs )
156207 if resp .status_code >= 500 :
157208 resp .raise_for_status () # TODO: Will probably retry here
158209 try :
159210 # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says
160211 # even an error response will be a valid json structure,
161212 # so we simply return it here, without needing to invent an exception.
162- return resp . json ( )
213+ return json . loads ( resp . text )
163214 except ValueError :
164215 self .logger .exception (
165216 "Token response is not in json format: %s" , resp .text )
@@ -200,7 +251,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
200251 grant_assertion_encoders = {GRANT_TYPE_SAML2 : BaseClient .encode_saml_assertion }
201252
202253
203- def initiate_device_flow (self , scope = None , timeout = None , ** kwargs ):
254+ def initiate_device_flow (self , scope = None , ** kwargs ):
204255 # type: (list, **dict) -> dict
205256 # The naming of this method is following the wording of this specs
206257 # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
@@ -218,10 +269,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
218269 DAE = "device_authorization_endpoint"
219270 if not self .configuration .get (DAE ):
220271 raise ValueError ("You need to provide device authorization endpoint" )
221- flow = self .session .post (self .configuration [DAE ],
272+ resp = self .http_client .post (self .configuration [DAE ],
222273 data = {"client_id" : self .client_id , "scope" : self ._stringify (scope or [])},
223- timeout = timeout or self .timeout ,
224- ** kwargs ).json ()
274+ headers = dict (self .default_headers , ** kwargs .pop ("headers" , {})),
275+ ** kwargs )
276+ flow = json .loads (resp .text )
225277 flow ["interval" ] = int (flow .get ("interval" , 5 )) # Some IdP returns string
226278 flow ["expires_in" ] = int (flow .get ("expires_in" , 1800 ))
227279 flow ["expires_at" ] = time .time () + flow ["expires_in" ] # We invent this
0 commit comments