Skip to content

Commit cd210bc

Browse files
committed
allowed custom headers
1 parent e6376dd commit cd210bc

File tree

5 files changed

+27
-14
lines changed

5 files changed

+27
-14
lines changed

ipinfo/handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def __init__(self, access_token=None, **kwargs):
9595
cache_options["ttl"] = CACHE_TTL
9696
self.cache = DefaultCache(**cache_options)
9797

98+
# setup custom headers
99+
self.headers = kwargs.get("headers", None)
100+
98101
def getDetails(self, ip_address=None, timeout=None):
99102
"""
100103
Get details for specified IP address as a Details object.
@@ -133,7 +136,7 @@ def getDetails(self, ip_address=None, timeout=None):
133136
url = API_URL
134137
if ip_address:
135138
url += "/" + ip_address
136-
headers = handler_utils.get_headers(self.access_token)
139+
headers = handler_utils.get_headers(self.access_token, self.headers)
137140
response = requests.get(url, headers=headers, **req_opts)
138141
if response.status_code == 429:
139142
raise RequestQuotaExceededError()
@@ -226,7 +229,7 @@ def getBatchDetails(
226229

227230
# loop over batch chunks and do lookup for each.
228231
url = API_URL + "/batch"
229-
headers = handler_utils.get_headers(self.access_token)
232+
headers = handler_utils.get_headers(self.access_token, self.headers)
230233
headers["content-type"] = "application/json"
231234
for i in range(0, len(lookup_addresses), batch_size):
232235
# quit if total timeout is reached.
@@ -295,7 +298,7 @@ def getMap(self, ips):
295298

296299
req_opts = {**self.request_options}
297300
url = f"{API_URL}/map?cli=1"
298-
headers = handler_utils.get_headers(None)
301+
headers = handler_utils.get_headers(None, self.headers)
299302
headers["content-type"] = "application/json"
300303
response = requests.post(
301304
url, json=ip_strs, headers=headers, **req_opts

ipinfo/handler_async.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def __init__(self, access_token=None, **kwargs):
9898
if "ttl" not in cache_options:
9999
cache_options["ttl"] = CACHE_TTL
100100
self.cache = DefaultCache(**cache_options)
101-
101+
102+
# setup custom headers
103+
self.headers = kwargs.get("headers", None)
104+
102105
async def init(self):
103106
"""
104107
Initializes internal aiohttp connection pool.
@@ -153,7 +156,7 @@ async def getDetails(self, ip_address=None, timeout=None):
153156
url = API_URL
154157
if ip_address:
155158
url += "/" + ip_address
156-
headers = handler_utils.get_headers(self.access_token)
159+
headers = handler_utils.get_headers(self.access_token, self.headers)
157160
req_opts = {}
158161
if timeout is not None:
159162
req_opts["timeout"] = timeout
@@ -251,7 +254,7 @@ async def getBatchDetails(
251254

252255
# loop over batch chunks and prepare coroutines for each.
253256
url = API_URL + "/batch"
254-
headers = handler_utils.get_headers(self.access_token)
257+
headers = handler_utils.get_headers(self.access_token, self.headers)
255258
headers["content-type"] = "application/json"
256259

257260
# prepare coroutines that will make reqs and update results.

ipinfo/handler_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,19 @@
4444
BATCH_REQ_TIMEOUT_DEFAULT = 5
4545

4646

47-
def get_headers(access_token):
48-
"""Build headers for request to IPinfo API."""
49-
headers = {
47+
def get_headers(access_token, custom_headers):
48+
headers = {}
49+
50+
if custom_headers:
51+
headers = custom_headers
52+
else:
53+
"""Build headers for request to IPinfo API."""
54+
headers = {
5055
"user-agent": "IPinfoClient/Python{version}/{sdk_version}".format(
5156
version=sys.version_info[0], sdk_version=SDK_VERSION
5257
),
5358
"accept": "application/json",
54-
}
59+
}
5560

5661
if access_token:
5762
headers["authorization"] = "Bearer {}".format(access_token)

tests/handler_async_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ async def test_init():
2121
@pytest.mark.asyncio
2222
async def test_headers():
2323
token = "mytesttoken"
24-
handler = AsyncHandler(token)
25-
headers = handler_utils.get_headers(token)
24+
handler = AsyncHandler(token, headers={"user-agent": "test-agent", "accept": "application/json", "custom_field": "yes"})
25+
headers = handler_utils.get_headers(token, handler.headers)
2626
await handler.deinit()
2727

2828
assert "user-agent" in headers
2929
assert "accept" in headers
3030
assert "authorization" in headers
31+
assert "custom_field" in headers
3132

3233

3334
@pytest.mark.asyncio

tests/handler_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ def test_init():
2020

2121
def test_headers():
2222
token = "mytesttoken"
23-
handler = Handler(token)
24-
headers = handler_utils.get_headers(token)
23+
handler = Handler(token, headers={"user-agent": "test-agent", "accept": "application/json", "custom_field": "yes"})
24+
headers = handler_utils.get_headers(token, handler.headers)
2525

2626
assert "user-agent" in headers
2727
assert "accept" in headers
2828
assert "authorization" in headers
29+
assert "custom_field" in headers
2930

3031

3132
def test_get_details():

0 commit comments

Comments
 (0)