Skip to content

Commit a1c5ad8

Browse files
committed
Improve type hinting
1 parent 382085e commit a1c5ad8

File tree

9 files changed

+297
-258
lines changed

9 files changed

+297
-258
lines changed

geoip2/_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class Model(metaclass=ABCMeta):
1010
def __eq__(self, other: object) -> bool:
1111
return isinstance(other, self.__class__) and self.to_dict() == other.to_dict()
1212

13-
def __ne__(self, other) -> bool:
13+
def __ne__(self, other: object) -> bool:
1414
return not self.__eq__(other)
1515

1616
# pylint: disable=too-many-branches

geoip2/database.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""The database reader for MaxMind MMDB files."""
22

3+
from __future__ import annotations
4+
35
import inspect
4-
import os
5-
from collections.abc import Sequence
6-
from typing import IO, Any, AnyStr, Optional, Union, cast
6+
from typing import IO, TYPE_CHECKING, AnyStr, cast
77

88
import maxminddb
99
from maxminddb import (
@@ -15,21 +15,30 @@
1515
MODE_MMAP_EXT,
1616
)
1717

18+
from maxminddb import InvalidDatabaseError
19+
1820
import geoip2
1921
import geoip2.errors
2022
import geoip2.models
21-
from geoip2.models import (
22-
ASN,
23-
ISP,
24-
AnonymousIP,
25-
AnonymousPlus,
26-
City,
27-
ConnectionType,
28-
Country,
29-
Domain,
30-
Enterprise,
31-
)
32-
from geoip2.types import IPAddress
23+
24+
if TYPE_CHECKING:
25+
import os
26+
from collections.abc import Sequence
27+
28+
from typing_extensions import Self
29+
30+
from geoip2.models import (
31+
ASN,
32+
ISP,
33+
AnonymousIP,
34+
AnonymousPlus,
35+
City,
36+
ConnectionType,
37+
Country,
38+
Domain,
39+
Enterprise,
40+
)
41+
from geoip2.types import IPAddress
3342

3443
__all__ = [
3544
"MODE_AUTO",
@@ -67,8 +76,8 @@ class Reader:
6776

6877
def __init__(
6978
self,
70-
fileish: Union[AnyStr, int, os.PathLike, IO],
71-
locales: Optional[Sequence[str]] = None,
79+
fileish: AnyStr | int | os.PathLike | IO,
80+
locales: Sequence[str] | None = None,
7281
mode: int = MODE_AUTO,
7382
) -> None:
7483
"""Create GeoIP2 Reader.
@@ -117,10 +126,10 @@ def __init__(
117126
self._db_type = self._db_reader.metadata().database_type
118127
self._locales = locales
119128

120-
def __enter__(self) -> "Reader":
129+
def __enter__(self) -> Self:
121130
return self
122131

123-
def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
132+
def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
124133
self.close()
125134

126135
def country(self, ip_address: IPAddress) -> Country:
@@ -249,10 +258,12 @@ def isp(self, ip_address: IPAddress) -> ISP:
249258
self._flat_model_for(geoip2.models.ISP, "GeoIP2-ISP", ip_address),
250259
)
251260

252-
def _get(self, database_type: str, ip_address: IPAddress) -> Any:
261+
def _get(self, database_type: str, ip_address: IPAddress) -> tuple[dict, int]:
253262
if database_type not in self._db_type:
254263
caller = inspect.stack()[2][3]
255-
msg = f"The {caller} method cannot be used with the {self._db_type} database"
264+
msg = (
265+
f"The {caller} method cannot be used with the {self._db_type} database"
266+
)
256267
raise TypeError(
257268
msg,
258269
)
@@ -264,14 +275,17 @@ def _get(self, database_type: str, ip_address: IPAddress) -> Any:
264275
str(ip_address),
265276
prefix_len,
266277
)
278+
if not isinstance(record, dict):
279+
msg = f"Expected record to be a dict but was f{type(record)}"
280+
raise InvalidDatabaseError(msg)
267281
return record, prefix_len
268282

269283
def _model_for(
270284
self,
271-
model_class: Union[type[Country], type[Enterprise], type[City]],
285+
model_class: type[City | Country | Enterprise],
272286
types: str,
273287
ip_address: IPAddress,
274-
) -> Union[Country, Enterprise, City]:
288+
) -> City | Country | Enterprise:
275289
(record, prefix_len) = self._get(types, ip_address)
276290
return model_class(
277291
self._locales,
@@ -282,16 +296,10 @@ def _model_for(
282296

283297
def _flat_model_for(
284298
self,
285-
model_class: Union[
286-
type[Domain],
287-
type[ISP],
288-
type[ConnectionType],
289-
type[ASN],
290-
type[AnonymousIP],
291-
],
299+
model_class: type[Domain | ISP | ConnectionType | ASN | AnonymousIP],
292300
types: str,
293301
ip_address: IPAddress,
294-
) -> Union[ConnectionType, ISP, AnonymousIP, Domain, ASN]:
302+
) -> ConnectionType | ISP | AnonymousIP | Domain | ASN:
295303
(record, prefix_len) = self._get(types, ip_address)
296304
return model_class(ip_address, prefix_len=prefix_len, **record)
297305

geoip2/errors.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Typed errors thrown by this library."""
22

3+
from __future__ import annotations
4+
35
import ipaddress
4-
from typing import Optional, Union
56

67

78
class GeoIP2Error(RuntimeError):
@@ -16,24 +17,24 @@ class GeoIP2Error(RuntimeError):
1617
class AddressNotFoundError(GeoIP2Error):
1718
"""The address you were looking up was not found."""
1819

19-
ip_address: Optional[str]
20+
ip_address: str | None
2021
"""The IP address used in the lookup. This is only available for database
2122
lookups.
2223
"""
23-
_prefix_len: Optional[int]
24+
_prefix_len: int | None
2425

2526
def __init__(
2627
self,
2728
message: str,
28-
ip_address: Optional[str] = None,
29-
prefix_len: Optional[int] = None,
29+
ip_address: str | None = None,
30+
prefix_len: int | None = None,
3031
) -> None:
3132
super().__init__(message)
3233
self.ip_address = ip_address
3334
self._prefix_len = prefix_len
3435

3536
@property
36-
def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
37+
def network(self) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
3738
"""The network associated with the error.
3839
3940
In particular, this is the largest network where no address would be
@@ -42,7 +43,8 @@ def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network
4243
if self.ip_address is None or self._prefix_len is None:
4344
return None
4445
return ipaddress.ip_network(
45-
f"{self.ip_address}/{self._prefix_len}", strict=False,
46+
f"{self.ip_address}/{self._prefix_len}",
47+
strict=False,
4648
)
4749

4850

@@ -58,19 +60,19 @@ class HTTPError(GeoIP2Error):
5860
5961
"""
6062

61-
http_status: Optional[int]
63+
http_status: int | None
6264
"""The HTTP status code returned"""
63-
uri: Optional[str]
65+
uri: str | None
6466
"""The URI queried"""
65-
decoded_content: Optional[str]
67+
decoded_content: str | None
6668
"""The decoded response content"""
6769

6870
def __init__(
6971
self,
7072
message: str,
71-
http_status: Optional[int] = None,
72-
uri: Optional[str] = None,
73-
decoded_content: Optional[str] = None,
73+
http_status: int | None = None,
74+
uri: str | None = None,
75+
decoded_content: str | None = None,
7476
) -> None:
7577
super().__init__(message)
7678
self.http_status = http_status

0 commit comments

Comments
 (0)