@@ -25,6 +25,7 @@ def __init__(
2525 client_secret = None , # type: Optional[str]
2626 client_assertion = None , # type: Optional[str]
2727 client_assertion_type = None , # type: Optional[str]
28+ default_headers = None , # type: Optional[dict]
2829 default_body = None , # type: Optional[dict]
2930 ):
3031 """Initialize a client object to talk all the OAuth2 grants to the server.
@@ -49,6 +50,9 @@ def __init__(
4950 a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
5051 the only two profiles defined in RFC 7521.
5152 But you can also explicitly provide a value, if needed.
53+ default_headers (dict):
54+ A dict to be sent in each request header.
55+ It is not required by OAuth2 specs, but you may use it for telemetry.
5256 default_body (dict):
5357 A dict to be sent in each token request body. For example,
5458 you could choose to set this as {"client_secret": "your secret"}
@@ -58,6 +62,7 @@ def __init__(
5862 self .configuration = server_configuration
5963 self .client_id = client_id
6064 self .client_secret = client_secret
65+ self .default_headers = default_headers or {}
6166 self .default_body = default_body or {}
6267 if client_assertion is not None : # See https://tools.ietf.org/html/rfc7521#section-4.2
6368 TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@@ -84,9 +89,10 @@ def _build_auth_request_params(self, response_type, **kwargs):
8489
8590 def _obtain_token ( # The verb "obtain" is influenced by OAUTH2 RFC 6749
8691 self , grant_type ,
87- params = None , # a dict to be send as query string to the endpoint
92+ params = None , # a dict to be sent as query string to the endpoint
8893 data = None , # All relevant data, which will go into the http body
89- timeout = None , # A timeout value which will be used by requests lib
94+ headers = None , # a dict to be sent as request headers
95+ ** kwargs # Relay all extra parameters to underlying requests
9096 ): # Returns the json object came from the OAUTH2 response
9197 _data = {'client_id' : self .client_id , 'grant_type' : grant_type }
9298 _data .update (self .default_body ) # It may contain authen parameters
@@ -109,10 +115,12 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
109115
110116 if "token_endpoint" not in self .configuration :
111117 raise ValueError ("token_endpoint not found in configuration" )
118+ _headers = {'Accept' : 'application/json' }
119+ _headers .update (self .default_headers )
120+ _headers .update (headers or {})
112121 resp = requests .post (
113122 self .configuration ["token_endpoint" ],
114- headers = {'Accept' : 'application/json' },
115- params = params , data = _data , auth = auth , timeout = timeout )
123+ headers = _headers , params = params , data = _data , auth = auth , ** kwargs )
116124 if resp .status_code >= 500 :
117125 resp .raise_for_status () # TODO: Will probably retry here
118126 try :
@@ -174,9 +182,9 @@ def initiate_device_flow(self, scope=None, **kwargs):
174182 DAE = "device_authorization_endpoint"
175183 if not self .configuration .get (DAE ):
176184 raise ValueError ("You need to provide device authorization endpoint" )
177- flow = requests .post (self .configuration [DAE ], data = {
178- "client_id" : self .client_id , "scope" : self ._stringify (scope or []),
179- }, ** kwargs ).json ()
185+ flow = requests .post (self .configuration [DAE ], headers = self . default_headers ,
186+ data = { "client_id" : self .client_id , "scope" : self ._stringify (scope or [])} ,
187+ ** kwargs ).json ()
180188 flow ["interval" ] = int (flow .get ("interval" , 5 )) # Some IdP returns string
181189 flow ["expires_in" ] = int (flow .get ("expires_in" , 1800 ))
182190 return flow
0 commit comments