Skip to content

Commit f4678ff

Browse files
committed
Support default and adhoc headers, relay **kwargs to requests
1 parent 2ea4a36 commit f4678ff

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

oauth2cli/oauth2.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
client_secret=None, # type: Optional[str]
2626
client_assertion=None, # type: Optional[str]
2727
client_assertion_type=None, # type: Optional[str]
28+
default_headers=None, # type: Optional[dict]
2829
default_body=None, # type: Optional[dict]
2930
):
3031
"""Initialize a client object to talk all the OAuth2 grants to the server.
@@ -49,6 +50,9 @@ def __init__(
4950
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
5051
the only two profiles defined in RFC 7521.
5152
But you can also explicitly provide a value, if needed.
53+
default_headers (dict):
54+
A dict to be sent in each request header.
55+
It is not required by OAuth2 specs, but you may use it for telemetry.
5256
default_body (dict):
5357
A dict to be sent in each token request body. For example,
5458
you could choose to set this as {"client_secret": "your secret"}
@@ -58,6 +62,7 @@ def __init__(
5862
self.configuration = server_configuration
5963
self.client_id = client_id
6064
self.client_secret = client_secret
65+
self.default_headers = default_headers or {}
6166
self.default_body = default_body or {}
6267
if client_assertion is not None: # See https://tools.ietf.org/html/rfc7521#section-4.2
6368
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@@ -84,9 +89,10 @@ def _build_auth_request_params(self, response_type, **kwargs):
8489

8590
def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
8691
self, grant_type,
87-
params=None, # a dict to be send as query string to the endpoint
92+
params=None, # a dict to be sent as query string to the endpoint
8893
data=None, # All relevant data, which will go into the http body
89-
timeout=None, # A timeout value which will be used by requests lib
94+
headers=None, # a dict to be sent as request headers
95+
**kwargs # Relay all extra parameters to underlying requests
9096
): # Returns the json object came from the OAUTH2 response
9197
_data = {'client_id': self.client_id, 'grant_type': grant_type}
9298
_data.update(self.default_body) # It may contain authen parameters
@@ -109,10 +115,12 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
109115

110116
if "token_endpoint" not in self.configuration:
111117
raise ValueError("token_endpoint not found in configuration")
118+
_headers = {'Accept': 'application/json'}
119+
_headers.update(self.default_headers)
120+
_headers.update(headers or {})
112121
resp = requests.post(
113122
self.configuration["token_endpoint"],
114-
headers={'Accept': 'application/json'},
115-
params=params, data=_data, auth=auth, timeout=timeout)
123+
headers=_headers, params=params, data=_data, auth=auth, **kwargs)
116124
if resp.status_code >= 500:
117125
resp.raise_for_status() # TODO: Will probably retry here
118126
try:
@@ -174,9 +182,9 @@ def initiate_device_flow(self, scope=None, **kwargs):
174182
DAE = "device_authorization_endpoint"
175183
if not self.configuration.get(DAE):
176184
raise ValueError("You need to provide device authorization endpoint")
177-
flow = requests.post(self.configuration[DAE], data={
178-
"client_id": self.client_id, "scope": self._stringify(scope or []),
179-
}, **kwargs).json()
185+
flow = requests.post(self.configuration[DAE], headers=self.default_headers,
186+
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
187+
**kwargs).json()
180188
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
181189
flow["expires_in"] = int(flow.get("expires_in", 1800))
182190
return flow

0 commit comments

Comments
 (0)