|
| 1 | +""" |
| 2 | +Main API client asynchronous handler for fetching data from the IPinfo service. |
| 3 | +""" |
| 4 | + |
| 5 | +from ipaddress import IPv4Address, IPv6Address |
| 6 | +import json |
| 7 | +import os |
| 8 | +import sys |
| 9 | + |
| 10 | +import requests |
| 11 | + |
| 12 | +from .cache.default import DefaultCache |
| 13 | +from .details import Details |
| 14 | +from .exceptions import RequestQuotaExceededError |
| 15 | + |
| 16 | + |
| 17 | +class AsyncHandler: |
| 18 | + """ |
| 19 | + Allows client to request data for specified IP address asynchronously. |
| 20 | + Instantiates and maintains access to cache. |
| 21 | + """ |
| 22 | + |
| 23 | + API_URL = "https://ipinfo.io" |
| 24 | + CACHE_MAXSIZE = 4096 |
| 25 | + CACHE_TTL = 60 * 60 * 24 |
| 26 | + COUNTRY_FILE_DEFAULT = "countries.json" |
| 27 | + REQUEST_TIMEOUT_DEFAULT = 2 |
| 28 | + |
| 29 | + def __init__(self, access_token=None, **kwargs): |
| 30 | + """Initialize the Handler object with country name list and the cache initialized.""" |
| 31 | + self.access_token = access_token |
| 32 | + self.countries = self._read_country_names(kwargs.get("countries_file")) |
| 33 | + self.request_options = kwargs.get("request_options", {}) |
| 34 | + if "timeout" not in self.request_options: |
| 35 | + self.request_options["timeout"] = self.REQUEST_TIMEOUT_DEFAULT |
| 36 | + |
| 37 | + if "cache" in kwargs: |
| 38 | + self.cache = kwargs["cache"] |
| 39 | + else: |
| 40 | + cache_options = kwargs.get("cache_options", {}) |
| 41 | + if "maxsize" not in cache_options: |
| 42 | + cache_options["maxsize"] = self.CACHE_MAXSIZE |
| 43 | + if "ttl" not in cache_options: |
| 44 | + cache_options["ttl"] = self.CACHE_TTL |
| 45 | + self.cache = DefaultCache(**cache_options) |
| 46 | + |
| 47 | + async def getDetails(self, ip_address=None): |
| 48 | + """Get details for specified IP address as a Details object.""" |
| 49 | + # If the supplied IP address uses the objects defined in the built-in |
| 50 | + # module ipaddress, extract the appropriate string notation before |
| 51 | + # formatting the URL. |
| 52 | + if isinstance(ip_address, IPv4Address) or isinstance(ip_address, IPv6Address): |
| 53 | + ip_address = ip_address.exploded |
| 54 | + |
| 55 | + if ip_address in self.cache: |
| 56 | + return Details(self.cache[ip_address]) |
| 57 | + |
| 58 | + # not in cache; get result, format it, and put in cache. |
| 59 | + url = self.API_URL |
| 60 | + if ip_address: |
| 61 | + url += "/" + ip_address |
| 62 | + response = requests.get( |
| 63 | + url, headers=self._get_headers(), **self.request_options |
| 64 | + ) |
| 65 | + if response.status_code == 429: |
| 66 | + raise RequestQuotaExceededError() |
| 67 | + response.raise_for_status() |
| 68 | + raw_details = response.json() |
| 69 | + self._format_details(raw_details) |
| 70 | + self.cache[ip_address] = raw_details |
| 71 | + return Details(raw_details) |
| 72 | + |
| 73 | + async def getBatchDetails(self, ip_addresses): |
| 74 | + """Get details for a batch of IP addresses at once.""" |
| 75 | + result = {} |
| 76 | + |
| 77 | + # Pre-populate with anything we've got in the cache, and keep around |
| 78 | + # the IPs not in the cache. |
| 79 | + lookup_addresses = [] |
| 80 | + for ip_address in ip_addresses: |
| 81 | + # If the supplied IP address uses the objects defined in the |
| 82 | + # built-in module ipaddress extract the appropriate string notation |
| 83 | + # before formatting the URL. |
| 84 | + if isinstance(ip_address, IPv4Address) or isinstance(ip_address, IPv6Address): |
| 85 | + ip_address = ip_address.exploded |
| 86 | + |
| 87 | + if ip_address in self.cache: |
| 88 | + result[ip_address] = self.cache[ip_address] |
| 89 | + else: |
| 90 | + lookup_addresses.append(ip_address) |
| 91 | + |
| 92 | + # all in cache - return early. |
| 93 | + if len(lookup_addresses) == 0: |
| 94 | + return result |
| 95 | + |
| 96 | + # Do the lookup |
| 97 | + url = self.API_URL + "/batch" |
| 98 | + headers = self._get_headers() |
| 99 | + headers["content-type"] = "application/json" |
| 100 | + response = requests.post( |
| 101 | + url, json=lookup_addresses, headers=headers, **self.request_options |
| 102 | + ) |
| 103 | + if response.status_code == 429: |
| 104 | + raise RequestQuotaExceededError() |
| 105 | + response.raise_for_status() |
| 106 | + |
| 107 | + # Format & fill up cache |
| 108 | + json_response = response.json() |
| 109 | + for ip_address, details in json_response.items(): |
| 110 | + if isinstance(details, dict): |
| 111 | + self._format_details(details) |
| 112 | + self.cache[ip_address] = details |
| 113 | + |
| 114 | + # Merge cached results with new lookup |
| 115 | + result.update(json_response) |
| 116 | + |
| 117 | + return result |
| 118 | + |
| 119 | + def _get_headers(self): |
| 120 | + """Built headers for request to IPinfo API.""" |
| 121 | + headers = { |
| 122 | + "user-agent": "IPinfoClient/Python{version}/3.0.0".format( |
| 123 | + version=sys.version_info[0] |
| 124 | + ), |
| 125 | + "accept": "application/json", |
| 126 | + } |
| 127 | + |
| 128 | + if self.access_token: |
| 129 | + headers["authorization"] = "Bearer {}".format(self.access_token) |
| 130 | + |
| 131 | + return headers |
| 132 | + |
| 133 | + def _format_details(self, details): |
| 134 | + details["country_name"] = self.countries.get(details.get("country")) |
| 135 | + details["latitude"], details["longitude"] = self._read_coords( |
| 136 | + details.get("loc") |
| 137 | + ) |
| 138 | + |
| 139 | + def _read_coords(self, location): |
| 140 | + lat, lon = None, None |
| 141 | + coords = tuple(location.split(",")) if location else "" |
| 142 | + if len(coords) == 2 and coords[0] and coords[1]: |
| 143 | + lat, lon = coords[0], coords[1] |
| 144 | + return lat, lon |
| 145 | + |
| 146 | + def _read_country_names(self, countries_file=None): |
| 147 | + """Read list of countries from specified country file or default file.""" |
| 148 | + if not countries_file: |
| 149 | + countries_file = os.path.join( |
| 150 | + os.path.dirname(__file__), self.COUNTRY_FILE_DEFAULT |
| 151 | + ) |
| 152 | + with open(countries_file) as f: |
| 153 | + countries_json = f.read() |
| 154 | + |
| 155 | + return json.loads(countries_json) |
0 commit comments