Skip to content

Commit 14d1ceb

Browse files
committed
basic batch improvements.
chunking happens now; input list size limit doesnt exist.
1 parent 5d363ea commit 14d1ceb

File tree

6 files changed

+155
-72
lines changed

6 files changed

+155
-72
lines changed

ipinfo/handler.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,28 @@ def getDetails(self, ip_address=None):
8383

8484
return Details(details)
8585

86-
def getBatchDetails(self, ip_addresses):
87-
"""Get details for a batch of IP addresses at once."""
86+
def getBatchDetails(self, ip_addresses, batch_size=None):
87+
"""
88+
Get details for a batch of IP addresses at once.
89+
90+
There is no specified limit to the number of IPs this function can
91+
accept; it can handle as much as the user can fit in RAM (along with
92+
all of the response data, which is at least a magnitude larger than the
93+
input list).
94+
95+
The batch size can be adjusted with `batch_size` but is clipped to (and
96+
also defaults to) `handler_utils.BATCH_MAX_SIZE`.
97+
"""
98+
if batch_size == None:
99+
batch_size = handler_utils.BATCH_MAX_SIZE
100+
88101
result = {}
89102

90-
# Pre-populate with anything we've got in the cache, and keep around
103+
# pre-populate with anything we've got in the cache, and keep around
91104
# the IPs not in the cache.
92105
lookup_addresses = []
93106
for ip_address in ip_addresses:
94-
# If the supplied IP address uses the objects defined in the
107+
# if the supplied IP address uses the objects defined in the
95108
# built-in module ipaddress extract the appropriate string notation
96109
# before formatting the URL.
97110
if isinstance(ip_address, IPv4Address) or isinstance(
@@ -104,28 +117,35 @@ def getBatchDetails(self, ip_addresses):
104117
else:
105118
lookup_addresses.append(ip_address)
106119

107-
# Do the lookup
108-
url = handler_utils.API_URL + "/batch"
109-
headers = handler_utils.get_headers(self.access_token)
110-
headers["content-type"] = "application/json"
111-
response = requests.post(
112-
url, json=lookup_addresses, headers=headers, **self.request_options
113-
)
114-
if response.status_code == 429:
115-
raise RequestQuotaExceededError()
116-
response.raise_for_status()
117-
118-
# Fill up cache
119-
json_response = response.json()
120-
for ip_address, details in json_response.items():
121-
self.cache[ip_address] = details
122-
123-
# Merge cached results with new lookup
124-
result.update(json_response)
125-
126-
# Format every result
127-
for detail in result.values():
128-
if isinstance(detail, dict):
129-
handler_utils.format_details(detail, self.countries)
120+
# loop over batch chunks and do lookup for each.
121+
for i in range(0, len(ip_addresses), batch_size):
122+
chunk = ip_addresses[i : i + batch_size]
123+
124+
# lookup
125+
url = handler_utils.API_URL + "/batch"
126+
headers = handler_utils.get_headers(self.access_token)
127+
headers["content-type"] = "application/json"
128+
response = requests.post(
129+
url,
130+
json=lookup_addresses,
131+
headers=headers,
132+
**self.request_options
133+
)
134+
if response.status_code == 429:
135+
raise RequestQuotaExceededError()
136+
response.raise_for_status()
137+
138+
# fill cache
139+
json_response = response.json()
140+
for ip_address, details in json_response.items():
141+
self.cache[ip_address] = details
142+
143+
# merge cached results with new lookup
144+
result.update(json_response)
145+
146+
# format all
147+
for detail in result.values():
148+
if isinstance(detail, dict):
149+
handler_utils.format_details(detail, self.countries)
130150

131151
return result

ipinfo/handler_async.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from ipaddress import IPv4Address, IPv6Address
6+
import asyncio
67
import json
78
import os
89
import sys
@@ -112,10 +113,26 @@ async def getDetails(self, ip_address=None):
112113

113114
return Details(details)
114115

115-
async def getBatchDetails(self, ip_addresses):
116-
"""Get details for a batch of IP addresses at once."""
116+
async def getBatchDetails(self, ip_addresses, batch_size=None):
117+
"""
118+
Get details for a batch of IP addresses at once.
119+
120+
There is no specified limit to the number of IPs this function can
121+
accept; it can handle as much as the user can fit in RAM (along with
122+
all of the response data, which is at least a magnitude larger than the
123+
input list).
124+
125+
The batch size can be adjusted with `batch_size` but is clipped to (and
126+
also defaults to) `handler_utils.BATCH_MAX_SIZE`.
127+
128+
The concurrency level is currently unadjustable; coroutines will be
129+
created and consumed for all batches at once.
130+
"""
117131
self._ensure_aiohttp_ready()
118132

133+
if batch_size == None:
134+
batch_size = handler_utils.BATCH_MAX_SIZE
135+
119136
result = {}
120137

121138
# Pre-populate with anything we've got in the cache, and keep around
@@ -135,32 +152,41 @@ async def getBatchDetails(self, ip_addresses):
135152
else:
136153
lookup_addresses.append(ip_address)
137154

138-
# all in cache - return early.
139-
if len(lookup_addresses) == 0:
140-
return result
141-
142-
# do http req
143-
url = handler_utils.API_URL + "/batch"
144-
headers = handler_utils.get_headers(self.access_token)
145-
headers["content-type"] = "application/json"
146-
async with self.httpsess.post(
147-
url,
148-
data=json.dumps(lookup_addresses),
149-
headers=headers
150-
) as resp:
155+
# loop over batch chunks and prepare coroutines for each.
156+
reqs = []
157+
for i in range(0, len(ip_addresses), batch_size):
158+
chunk = ip_addresses[i : i + batch_size]
159+
160+
# all in cache - return early.
161+
if len(lookup_addresses) == 0:
162+
return result
163+
164+
# do http req
165+
url = handler_utils.API_URL + "/batch"
166+
headers = handler_utils.get_headers(self.access_token)
167+
headers["content-type"] = "application/json"
168+
reqs.append(
169+
self.httpsess.post(
170+
url, data=json.dumps(lookup_addresses), headers=headers
171+
)
172+
)
173+
174+
resps = await asyncio.gather(*reqs)
175+
for resp in resps:
176+
# gather data
151177
if resp.status == 429:
152178
raise RequestQuotaExceededError()
153179
resp.raise_for_status()
154180
json_resp = await resp.json()
155181

156-
# format & fill up cache
157-
for ip_address, details in json_resp.items():
158-
if isinstance(details, dict):
159-
handler_utils.format_details(details, self.countries)
160-
self.cache[ip_address] = details
182+
# format & fill up cache
183+
for ip_address, details in json_resp.items():
184+
if isinstance(details, dict):
185+
handler_utils.format_details(details, self.countries)
186+
self.cache[ip_address] = details
161187

162-
# merge cached results with new lookup
163-
result.update(json_resp)
188+
# merge cached results with new lookup
189+
result.update(json_resp)
164190

165191
return result
166192

ipinfo/handler_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,22 @@
88

99
from .version import SDK_VERSION
1010

11+
# Base URL to make requests against.
1112
API_URL = "https://ipinfo.io"
13+
14+
# Used to transform incoming responses with country abbreviations into the full
15+
# expanded country name, e.g. "PK" -> "Pakistan".
1216
COUNTRY_FILE_DEFAULT = "countries.json"
1317

18+
# The max amount of IPs allowed by the API per batch request.
19+
BATCH_MAX_SIZE = 1000
20+
21+
1422
def get_headers(access_token):
1523
"""Build headers for request to IPinfo API."""
1624
headers = {
1725
"user-agent": "IPinfoClient/Python{version}/{sdk_version}".format(
18-
version=sys.version_info[0],
19-
sdk_version=SDK_VERSION
26+
version=sys.version_info[0], sdk_version=SDK_VERSION
2027
),
2128
"accept": "application/json",
2229
}
@@ -26,16 +33,16 @@ def get_headers(access_token):
2633

2734
return headers
2835

36+
2937
def format_details(details, countries):
3038
"""
3139
Format details given a countries object.
3240
3341
The countries object can be retrieved from read_country_names.
3442
"""
3543
details["country_name"] = countries.get(details.get("country"))
36-
details["latitude"], details["longitude"] = read_coords(
37-
details.get("loc")
38-
)
44+
details["latitude"], details["longitude"] = read_coords(details.get("loc"))
45+
3946

4047
def read_coords(location):
4148
"""
@@ -50,6 +57,7 @@ def read_coords(location):
5057
lat, lon = coords[0], coords[1]
5158
return lat, lon
5259

60+
5361
def read_country_names(countries_file=None):
5462
"""
5563
Read list of countries from specified country file or

ipinfo/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SDK_VERSION = '4.1.0'
1+
SDK_VERSION = "4.1.0"

tests/handler_async_test.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ async def test_headers():
2929
assert "authorization" in headers
3030

3131

32-
@pytest.mark.parametrize("n", range(5))
3332
@pytest.mark.asyncio
34-
async def test_get_details(n):
33+
async def test_get_details():
3534
token = os.environ.get("IPINFO_TOKEN", "")
3635
handler = AsyncHandler(token)
3736
details = await handler.getDetails("8.8.8.8")
@@ -86,27 +85,42 @@ async def test_get_details(n):
8685
await handler.deinit()
8786

8887

89-
@pytest.mark.parametrize("n", range(5))
90-
@pytest.mark.asyncio
91-
async def test_get_batch_details(n):
88+
#############
89+
# BATCH TESTS
90+
#############
91+
92+
_batch_ip_addrs = ["1.1.1.1", "8.8.8.8", "9.9.9.9"]
93+
94+
95+
def _prepare_batch_test():
96+
"""Helper for preparing batch test cases."""
9297
token = os.environ.get("IPINFO_TOKEN", "")
9398
if not token:
9499
pytest.skip("token required for batch tests")
95100
handler = AsyncHandler(token)
96-
ips = ["1.1.1.1", "8.8.8.8", "9.9.9.9"]
97-
details = await handler.getBatchDetails(ips)
101+
return handler, token, _batch_ip_addrs
98102

103+
104+
def _check_batch_details(ips, details, token):
105+
"""Helper for batch tests."""
99106
for ip in ips:
100107
assert ip in details
101108
d = details[ip]
102109
assert d["ip"] == ip
103-
assert d["country"] == "US"
104-
assert d["country_name"] == "United States"
110+
assert "country" in d
111+
assert "country_name" in d
105112
if token:
106113
assert "asn" in d
107114
assert "company" in d
108115
assert "privacy" in d
109116
assert "abuse" in d
110117
assert "domains" in d
111118

119+
120+
@pytest.mark.parametrize("batch_size", [None, 2, 3])
121+
@pytest.mark.asyncio
122+
async def test_get_batch_details(batch_size):
123+
handler, token, ips = _prepare_batch_test()
124+
details = await handler.getBatchDetails(ips, batch_size=batch_size)
125+
_check_batch_details(ips, details, token)
112126
await handler.deinit()

tests/handler_test.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def test_headers():
2727
assert "authorization" in headers
2828

2929

30-
@pytest.mark.parametrize("n", range(5))
31-
def test_get_details(n):
30+
def test_get_details():
3231
token = os.environ.get("IPINFO_TOKEN", "")
3332
handler = Handler(token)
3433
details = handler.getDetails("8.8.8.8")
@@ -81,24 +80,40 @@ def test_get_details(n):
8180
assert len(domains["domains"]) == 5
8281

8382

84-
@pytest.mark.parametrize("n", range(5))
85-
def test_get_batch_details(n):
83+
#############
84+
# BATCH TESTS
85+
#############
86+
87+
_batch_ip_addrs = ["1.1.1.1", "8.8.8.8", "9.9.9.9"]
88+
89+
90+
def _prepare_batch_test():
91+
"""Helper for preparing batch test cases."""
8692
token = os.environ.get("IPINFO_TOKEN", "")
8793
if not token:
8894
pytest.skip("token required for batch tests")
8995
handler = Handler(token)
90-
ips = ["1.1.1.1", "8.8.8.8", "9.9.9.9"]
91-
details = handler.getBatchDetails(ips)
96+
return handler, token, _batch_ip_addrs
9297

98+
99+
def _check_batch_details(ips, details, token):
100+
"""Helper for batch tests."""
93101
for ip in ips:
94102
assert ip in details
95103
d = details[ip]
96104
assert d["ip"] == ip
97-
assert d["country"] == "US"
98-
assert d["country_name"] == "United States"
105+
assert "country" in d
106+
assert "country_name" in d
99107
if token:
100108
assert "asn" in d
101109
assert "company" in d
102110
assert "privacy" in d
103111
assert "abuse" in d
104112
assert "domains" in d
113+
114+
115+
@pytest.mark.parametrize("batch_size", [None, 2, 3])
116+
def test_get_batch_details(batch_size):
117+
handler, token, ips = _prepare_batch_test()
118+
details = handler.getBatchDetails(ips, batch_size=batch_size)
119+
_check_batch_details(ips, details, token)

0 commit comments

Comments
 (0)