Skip to content

Commit 4347bfd

Browse files
committed
test cases and more robust err handling
1 parent c3e5855 commit 4347bfd

File tree

5 files changed

+109
-49
lines changed

5 files changed

+109
-49
lines changed

ipinfo/handler.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,9 @@ def getBatchDetails(
179179
timeout_total is not None
180180
and time.time() - start_time > timeout_total
181181
):
182-
if raise_on_fail:
183-
raise TimeoutExceededError()
184-
else:
185-
return result
182+
return handler_utils.return_or_fail(
183+
raise_on_fail, TimeoutExceededError(), result
184+
)
186185

187186
chunk = lookup_addresses[i : i + batch_size]
188187

@@ -197,10 +196,7 @@ def getBatchDetails(
197196
raise RequestQuotaExceededError()
198197
response.raise_for_status()
199198
except Exception as e:
200-
if raise_on_fail:
201-
raise e
202-
else:
203-
return result
199+
return handler_utils.return_or_fail(raise_on_fail, e, result)
204200

205201
# fill cache
206202
json_response = response.json()

ipinfo/handler_async.py

Lines changed: 71 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import json
88
import os
99
import sys
10+
import time
1011

1112
import aiohttp
1213

1314
from .cache.default import DefaultCache
1415
from .details import Details
15-
from .exceptions import RequestQuotaExceededError
16+
from .exceptions import RequestQuotaExceededError, TimeoutExceededError
1617
from .handler_utils import (
1718
API_URL,
1819
COUNTRY_FILE_DEFAULT,
@@ -197,49 +198,80 @@ async def getBatchDetails(
197198
url = API_URL + "/batch"
198199
headers = handler_utils.get_headers(self.access_token)
199200
headers["content-type"] = "application/json"
200-
reqs = []
201-
for i in range(0, len(lookup_addresses), batch_size):
202-
chunk = lookup_addresses[i : i + batch_size]
203-
204-
# do http req
205-
reqs.append(
206-
self.httpsess.post(
207-
url,
208-
data=json.dumps(chunk),
209-
headers=headers,
210-
timeout=timeout_per_batch,
211-
)
201+
202+
# prepare coroutines that will make reqs and update results.
203+
reqs = [
204+
self._do_batch_req(
205+
lookup_addresses[i : i + batch_size],
206+
url,
207+
headers,
208+
timeout_per_batch,
209+
raise_on_fail,
210+
result,
211+
)
212+
for i in range(0, len(lookup_addresses), batch_size)
213+
]
214+
215+
try:
216+
_, pending = await asyncio.wait(
217+
{*reqs},
218+
timeout=timeout_total,
219+
return_when=asyncio.FIRST_EXCEPTION,
212220
)
213221

214-
resps = await asyncio.wait_for(
215-
asyncio.gather(*reqs, return_exceptions=raise_on_fail),
216-
timeout_total
217-
)
218-
for resp in resps:
219-
# gather data
220-
try:
221-
if resp.status == 429:
222-
raise RequestQuotaExceededError()
223-
resp.raise_for_status()
224-
except Exception as e:
225-
if raise_on_fail:
226-
raise e
227-
else:
228-
return result
229-
230-
json_resp = await resp.json()
231-
232-
# format & fill up cache
233-
for ip_address, details in json_resp.items():
234-
if isinstance(details, dict):
235-
handler_utils.format_details(details, self.countries)
236-
self.cache[ip_address] = details
237-
238-
# merge cached results with new lookup
239-
result.update(json_resp)
222+
# if all done, return result.
223+
if len(pending) == 0:
224+
return result
225+
226+
# if some had a timeout, first cancel timed out stuff and wait for
227+
# cleanup. then exit with return_or_fail.
228+
for co in pending:
229+
try:
230+
co.cancel()
231+
await co
232+
except asyncio.CancelledError:
233+
pass
234+
235+
return handler_utils.return_or_fail(
236+
raise_on_fail, TimeoutExceededError(), result
237+
)
238+
except Exception as e:
239+
return handler_utils.return_or_fail(raise_on_fail, e, result)
240240

241241
return result
242242

243+
async def _do_batch_req(
244+
self, chunk, url, headers, timeout_per_batch, raise_on_fail, result
245+
):
246+
"""
247+
Coroutine which will do the actual POST request for getBatchDetails.
248+
"""
249+
resp = await self.httpsess.post(
250+
url,
251+
data=json.dumps(chunk),
252+
headers=headers,
253+
timeout=timeout_per_batch,
254+
)
255+
256+
# gather data
257+
try:
258+
if resp.status == 429:
259+
raise RequestQuotaExceededError()
260+
resp.raise_for_status()
261+
except Exception as e:
262+
return handler_utils.return_or_fail(raise_on_fail, e, None)
263+
264+
json_resp = await resp.json()
265+
266+
# format & fill up cache
267+
for ip_address, details in json_resp.items():
268+
if isinstance(details, dict):
269+
handler_utils.format_details(details, self.countries)
270+
self.cache[ip_address] = details
271+
272+
# merge cached results with new lookup
273+
result.update(json_resp)
274+
243275
def _ensure_aiohttp_ready(self):
244276
"""Ensures aiohttp internal state is initialized."""
245277
if self.httpsess:

ipinfo/handler_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,13 @@ def read_country_names(countries_file=None):
8383
countries_json = f.read()
8484

8585
return json.loads(countries_json)
86+
87+
88+
def return_or_fail(raise_on_fail, e, v):
89+
"""
90+
Either throws `e` if `raise_on_fail` or else returns `v`.
91+
"""
92+
if raise_on_fail:
93+
raise e
94+
else:
95+
return v

tests/handler_async_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ipinfo.details import Details
55
from ipinfo.handler_async import AsyncHandler
66
from ipinfo import handler_utils
7+
import ipinfo
78
import pytest
89

910

@@ -117,10 +118,21 @@ def _check_batch_details(ips, details, token):
117118
assert "domains" in d
118119

119120

120-
@pytest.mark.parametrize("batch_size", [None, 2, 3])
121+
@pytest.mark.parametrize("batch_size", [None, 1, 2, 3])
121122
@pytest.mark.asyncio
122123
async def test_get_batch_details(batch_size):
123124
handler, token, ips = _prepare_batch_test()
124125
details = await handler.getBatchDetails(ips, batch_size=batch_size)
125126
_check_batch_details(ips, details, token)
126127
await handler.deinit()
128+
129+
130+
@pytest.mark.parametrize("batch_size", [None, 1, 2, 3])
131+
@pytest.mark.asyncio
132+
async def test_get_batch_details_total_timeout(batch_size):
133+
handler, token, ips = _prepare_batch_test()
134+
with pytest.raises(ipinfo.exceptions.TimeoutExceededError):
135+
await handler.getBatchDetails(
136+
ips, batch_size=batch_size, timeout_total=0.001
137+
)
138+
await handler.deinit()

tests/handler_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ipinfo.details import Details
77
from ipinfo.handler import Handler
88
from ipinfo import handler_utils
9+
import ipinfo
910
import pytest
1011

1112

@@ -112,8 +113,17 @@ def _check_batch_details(ips, details, token):
112113
assert "domains" in d
113114

114115

115-
@pytest.mark.parametrize("batch_size", [None, 2, 3])
116+
@pytest.mark.parametrize("batch_size", [None, 1, 2, 3])
116117
def test_get_batch_details(batch_size):
117118
handler, token, ips = _prepare_batch_test()
118119
details = handler.getBatchDetails(ips, batch_size=batch_size)
119120
_check_batch_details(ips, details, token)
121+
122+
123+
@pytest.mark.parametrize("batch_size", [1, 2])
124+
def test_get_batch_details_total_timeout(batch_size):
125+
handler, token, ips = _prepare_batch_test()
126+
with pytest.raises(ipinfo.exceptions.TimeoutExceededError):
127+
handler.getBatchDetails(
128+
ips, batch_size=batch_size, timeout_total=0.001
129+
)

0 commit comments

Comments
 (0)