Skip to content

Commit 3b1bbda

Browse files
committed
refactor common handler utilities out
1 parent 1b98867 commit 3b1bbda

File tree

5 files changed

+84
-105
lines changed

5 files changed

+84
-105
lines changed

ipinfo/handler.py

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .cache.default import DefaultCache
1313
from .details import Details
1414
from .exceptions import RequestQuotaExceededError
15+
from . import handler_utils
1516

1617

1718
class Handler:
@@ -20,10 +21,8 @@ class Handler:
2021
Instantiates and maintains access to cache.
2122
"""
2223

23-
API_URL = "https://ipinfo.io"
2424
CACHE_MAXSIZE = 4096
2525
CACHE_TTL = 60 * 60 * 24
26-
COUNTRY_FILE_DEFAULT = "countries.json"
2726
REQUEST_TIMEOUT_DEFAULT = 2
2827

2928
def __init__(self, access_token=None, **kwargs):
@@ -34,7 +33,9 @@ def __init__(self, access_token=None, **kwargs):
3433
self.access_token = access_token
3534

3635
# load countries file
37-
self.countries = self._read_country_names(kwargs.get("countries_file"))
36+
self.countries = handler_utils.read_country_names(
37+
kwargs.get("countries_file")
38+
)
3839

3940
# setup req opts
4041
self.request_options = kwargs.get("request_options", {})
@@ -55,7 +56,7 @@ def __init__(self, access_token=None, **kwargs):
5556
def getDetails(self, ip_address=None):
5657
"""Get details for specified IP address as a Details object."""
5758
raw_details = self._requestDetails(ip_address)
58-
self._format_details(raw_details)
59+
handler_utils.format_details(raw_details, self.countries)
5960
return Details(raw_details)
6061

6162
def getBatchDetails(self, ip_addresses):
@@ -80,8 +81,8 @@ def getBatchDetails(self, ip_addresses):
8081
lookup_addresses.append(ip_address)
8182

8283
# Do the lookup
83-
url = self.API_URL + "/batch"
84-
headers = self._get_headers()
84+
url = handler_utils.API_URL + "/batch"
85+
headers = handler_utils.get_headers(self.access_token)
8586
headers["content-type"] = "application/json"
8687
response = requests.post(
8788
url, json=lookup_addresses, headers=headers, **self.request_options
@@ -101,7 +102,7 @@ def getBatchDetails(self, ip_addresses):
101102
# Format every result
102103
for detail in result.values():
103104
if isinstance(detail, dict):
104-
self._format_details(detail)
105+
handler_utils.format_details(detail, self.countries)
105106

106107
return result
107108

@@ -117,57 +118,18 @@ def _requestDetails(self, ip_address=None):
117118
ip_address = ip_address.exploded
118119

119120
if ip_address not in self.cache:
120-
url = self.API_URL
121+
url = handler_utils.API_URL
121122
if ip_address:
122123
url += "/" + ip_address
123124

124125
response = requests.get(
125-
url, headers=self._get_headers(), **self.request_options
126+
url,
127+
headers=handler_utils.get_headers(self.access_token),
128+
**self.request_options
126129
)
127130
if response.status_code == 429:
128131
raise RequestQuotaExceededError()
129132
response.raise_for_status()
130133
self.cache[ip_address] = response.json()
131134

132135
return self.cache[ip_address]
133-
134-
def _get_headers(self):
135-
"""Built headers for request to IPinfo API."""
136-
headers = {
137-
"user-agent": "IPinfoClient/Python{version}/4.0.0".format(
138-
version=sys.version_info[0]
139-
),
140-
"accept": "application/json",
141-
}
142-
143-
if self.access_token:
144-
headers["authorization"] = "Bearer {}".format(self.access_token)
145-
146-
return headers
147-
148-
def _format_details(self, details):
149-
details["country_name"] = self.countries.get(details.get("country"))
150-
details["latitude"], details["longitude"] = self._read_coords(
151-
details.get("loc")
152-
)
153-
154-
def _read_coords(self, location):
155-
lat, lon = None, None
156-
coords = tuple(location.split(",")) if location else ""
157-
if len(coords) == 2 and coords[0] and coords[1]:
158-
lat, lon = coords[0], coords[1]
159-
return lat, lon
160-
161-
def _read_country_names(self, countries_file=None):
162-
"""
163-
Read list of countries from specified country file or
164-
default file.
165-
"""
166-
if not countries_file:
167-
countries_file = os.path.join(
168-
os.path.dirname(__file__), self.COUNTRY_FILE_DEFAULT
169-
)
170-
with open(countries_file) as f:
171-
countries_json = f.read()
172-
173-
return json.loads(countries_json)

ipinfo/handler_async.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .cache.default import DefaultCache
1313
from .details import Details
1414
from .exceptions import RequestQuotaExceededError
15+
from . import handler_utils
1516

1617

1718
class AsyncHandler:
@@ -20,10 +21,8 @@ class AsyncHandler:
2021
Instantiates and maintains access to cache.
2122
"""
2223

23-
API_URL = "https://ipinfo.io"
2424
CACHE_MAXSIZE = 4096
2525
CACHE_TTL = 60 * 60 * 24
26-
COUNTRY_FILE_DEFAULT = "countries.json"
2726
REQUEST_TIMEOUT_DEFAULT = 2
2827

2928
def __init__(self, access_token=None, **kwargs):
@@ -34,7 +33,9 @@ def __init__(self, access_token=None, **kwargs):
3433
self.access_token = access_token
3534

3635
# load countries file
37-
self.countries = self._read_country_names(kwargs.get("countries_file"))
36+
self.countries = handler_utils.read_country_names(
37+
kwargs.get("countries_file")
38+
)
3839

3940
# setup req opts
4041
self.request_options = kwargs.get("request_options", {})
@@ -95,18 +96,18 @@ async def getDetails(self, ip_address=None):
9596
return Details(self.cache[ip_address])
9697

9798
# not in cache; do http req
98-
url = self.API_URL
99+
url = handler_utils.API_URL
99100
if ip_address:
100101
url += "/" + ip_address
101-
headers = self._get_headers()
102+
headers = handler_utils.get_headers(self.access_token)
102103
async with self.httpsess.get(url, headers=headers) as resp:
103104
if resp.status == 429:
104105
raise RequestQuotaExceededError()
105106
resp.raise_for_status()
106107
raw_details = await resp.json()
107108

108109
# format & cache
109-
self._format_details(raw_details)
110+
handler_utils.format_details(raw_details, self.countries)
110111
self.cache[ip_address] = raw_details
111112

112113
return Details(raw_details)
@@ -139,11 +140,13 @@ async def getBatchDetails(self, ip_addresses):
139140
return result
140141

141142
# do http req
142-
url = self.API_URL + "/batch"
143-
headers = self._get_headers()
143+
url = handler_utils.API_URL + "/batch"
144+
headers = handler_utils.get_headers(self.access_token)
144145
headers["content-type"] = "application/json"
145146
async with self.httpsess.post(
146-
url, data=json.dumps(lookup_addresses), headers=headers
147+
url,
148+
data=json.dumps(lookup_addresses),
149+
headers=headers
147150
) as resp:
148151
if resp.status == 429:
149152
raise RequestQuotaExceededError()
@@ -153,7 +156,7 @@ async def getBatchDetails(self, ip_addresses):
153156
# format & fill up cache
154157
for ip_address, details in json_resp.items():
155158
if isinstance(details, dict):
156-
self._format_details(details)
159+
handler_utils.format_details(details, self.countries)
157160
self.cache[ip_address] = details
158161

159162
# merge cached results with new lookup
@@ -168,44 +171,3 @@ def _ensure_aiohttp_ready(self):
168171

169172
timeout = aiohttp.ClientTimeout(total=self.request_options["timeout"])
170173
self.httpsess = aiohttp.ClientSession(timeout=timeout)
171-
172-
def _get_headers(self):
173-
"""Built headers for request to IPinfo API."""
174-
headers = {
175-
"user-agent": "IPinfoClient/Python{version}/4.0.0".format(
176-
version=sys.version_info[0]
177-
),
178-
"accept": "application/json",
179-
}
180-
181-
if self.access_token:
182-
headers["authorization"] = "Bearer {}".format(self.access_token)
183-
184-
return headers
185-
186-
def _format_details(self, details):
187-
details["country_name"] = self.countries.get(details.get("country"))
188-
details["latitude"], details["longitude"] = self._read_coords(
189-
details.get("loc")
190-
)
191-
192-
def _read_coords(self, location):
193-
lat, lon = None, None
194-
coords = tuple(location.split(",")) if location else ""
195-
if len(coords) == 2 and coords[0] and coords[1]:
196-
lat, lon = coords[0], coords[1]
197-
return lat, lon
198-
199-
def _read_country_names(self, countries_file=None):
200-
"""
201-
Read list of countries from specified country file or
202-
default file.
203-
"""
204-
if not countries_file:
205-
countries_file = os.path.join(
206-
os.path.dirname(__file__), self.COUNTRY_FILE_DEFAULT
207-
)
208-
with open(countries_file) as f:
209-
countries_json = f.read()
210-
211-
return json.loads(countries_json)

ipinfo/handler_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
Utilities used in handlers.
3+
"""
4+
5+
import json
6+
import os
7+
import sys
8+
9+
API_URL = "https://ipinfo.io"
10+
COUNTRY_FILE_DEFAULT = "countries.json"
11+
12+
def get_headers(access_token):
13+
"""Build headers for request to IPinfo API."""
14+
headers = {
15+
"user-agent": "IPinfoClient/Python{version}/4.0.0".format(
16+
version=sys.version_info[0]
17+
),
18+
"accept": "application/json",
19+
}
20+
21+
if access_token:
22+
headers["authorization"] = "Bearer {}".format(access_token)
23+
24+
return headers
25+
26+
def format_details(details, countries):
27+
details["country_name"] = countries.get(details.get("country"))
28+
details["latitude"], details["longitude"] = read_coords(
29+
details.get("loc")
30+
)
31+
32+
def read_coords(location):
33+
lat, lon = None, None
34+
coords = tuple(location.split(",")) if location else ""
35+
if len(coords) == 2 and coords[0] and coords[1]:
36+
lat, lon = coords[0], coords[1]
37+
return lat, lon
38+
39+
def read_country_names(countries_file=None):
40+
"""
41+
Read list of countries from specified country file or
42+
default file.
43+
"""
44+
if not countries_file:
45+
countries_file = os.path.join(
46+
os.path.dirname(__file__), COUNTRY_FILE_DEFAULT
47+
)
48+
with open(countries_file) as f:
49+
countries_json = f.read()
50+
51+
return json.loads(countries_json)

tests/handler_async_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ipinfo.cache.default import DefaultCache
44
from ipinfo.details import Details
55
from ipinfo.handler_async import AsyncHandler
6+
from ipinfo import handler_utils
67
import pytest
78

89

@@ -20,7 +21,7 @@ async def test_init():
2021
async def test_headers():
2122
token = "mytesttoken"
2223
handler = AsyncHandler(token)
23-
headers = handler._get_headers()
24+
headers = handler_utils.get_headers(token)
2425
await handler.deinit()
2526

2627
assert "user-agent" in headers
@@ -78,7 +79,8 @@ async def test_get_details(n):
7879

7980
domains = details.domains
8081
assert domains["ip"] == "8.8.8.8"
81-
assert domains["total"] == 12988
82+
# NOTE: actual number changes too much
83+
assert "total" in domains
8284
assert len(domains["domains"]) == 5
8385

8486
await handler.deinit()

tests/handler_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ipinfo.cache.default import DefaultCache
66
from ipinfo.details import Details
77
from ipinfo.handler import Handler
8+
from ipinfo import handler_utils
89
import pytest
910

1011

@@ -19,7 +20,7 @@ def test_init():
1920
def test_headers():
2021
token = "mytesttoken"
2122
handler = Handler(token)
22-
headers = handler._get_headers()
23+
headers = handler_utils.get_headers(token)
2324

2425
assert "user-agent" in headers
2526
assert "accept" in headers
@@ -75,7 +76,8 @@ def test_get_details(n):
7576

7677
domains = details.domains
7778
assert domains["ip"] == "8.8.8.8"
78-
assert domains["total"] == 12988
79+
# NOTE: actual number changes too much
80+
assert "total" in domains
7981
assert len(domains["domains"]) == 5
8082

8183

0 commit comments

Comments
 (0)