@@ -27,6 +27,9 @@ def __init__(
2727 client_assertion_type = None , # type: Optional[str]
2828 default_headers = None , # type: Optional[dict]
2929 default_body = None , # type: Optional[dict]
30+ verify = True , # type: Union[str, True, False, None]
31+ proxies = None , # type: Optional[dict]
32+ timeout = None , # type: Union[tuple, float, None]
3033 ):
3134 """Initialize a client object to talk all the OAuth2 grants to the server.
3235
@@ -62,16 +65,20 @@ def __init__(
6265 self .configuration = server_configuration
6366 self .client_id = client_id
6467 self .client_secret = client_secret
65- self .default_headers = default_headers or {}
6668 self .default_body = default_body or {}
6769 if client_assertion is not None : # See https://tools.ietf.org/html/rfc7521#section-4.2
68- TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
69- TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
7070 if client_assertion_type is None : # RFC7521 defines only 2 profiles
71+ TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
72+ TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
7173 client_assertion_type = TYPE_JWT if "." in client_assertion else TYPE_SAML2
7274 self .default_body ["client_assertion" ] = client_assertion
7375 self .default_body ["client_assertion_type" ] = client_assertion_type
7476 self .logger = logging .getLogger (__name__ )
77+ self .session = s = requests .Session ()
78+ s .headers .update (default_headers or {})
79+ s .verify = verify
80+ s .proxies = proxies or {}
81+ self .timeout = timeout
7582
7683 def _build_auth_request_params (self , response_type , ** kwargs ):
7784 # response_type is a string defined in
@@ -92,6 +99,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
9299 params = None , # a dict to be sent as query string to the endpoint
93100 data = None , # All relevant data, which will go into the http body
94101 headers = None , # a dict to be sent as request headers
102+ timeout = None ,
95103 ** kwargs # Relay all extra parameters to underlying requests
96104 ): # Returns the json object came from the OAUTH2 response
97105 _data = {'client_id' : self .client_id , 'grant_type' : grant_type }
@@ -116,11 +124,12 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
116124 if "token_endpoint" not in self .configuration :
117125 raise ValueError ("token_endpoint not found in configuration" )
118126 _headers = {'Accept' : 'application/json' }
119- _headers .update (self .default_headers )
120127 _headers .update (headers or {})
121- resp = requests .post (
128+ resp = self . session .post (
122129 self .configuration ["token_endpoint" ],
123- headers = _headers , params = params , data = _data , auth = auth , ** kwargs )
130+ headers = _headers , params = params , data = _data , auth = auth ,
131+ timeout = timeout or self .timeout ,
132+ ** kwargs )
124133 if resp .status_code >= 500 :
125134 resp .raise_for_status () # TODO: Will probably retry here
126135 try :
@@ -164,7 +173,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
164173 GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
165174 GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
166175
167- def initiate_device_flow (self , scope = None , ** kwargs ):
176+ def initiate_device_flow (self , scope = None , timeout = None , ** kwargs ):
168177 # type: (list, **dict) -> dict
169178 # The naming of this method is following the wording of this specs
170179 # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
@@ -182,8 +191,9 @@ def initiate_device_flow(self, scope=None, **kwargs):
182191 DAE = "device_authorization_endpoint"
183192 if not self .configuration .get (DAE ):
184193 raise ValueError ("You need to provide device authorization endpoint" )
185- flow = requests . post (self .configuration [DAE ], headers = self . default_headers ,
194+ flow = self . session . post (self .configuration [DAE ],
186195 data = {"client_id" : self .client_id , "scope" : self ._stringify (scope or [])},
196+ timeout = timeout or self .timeout ,
187197 ** kwargs ).json ()
188198 flow ["interval" ] = int (flow .get ("interval" , 5 )) # Some IdP returns string
189199 flow ["expires_in" ] = int (flow .get ("expires_in" , 1800 ))
0 commit comments