11"""The database reader for MaxMind MMDB files."""
22
3+ from __future__ import annotations
4+
35import inspect
4- import os
5- from collections .abc import Sequence
6- from typing import IO , Any , AnyStr , Optional , Union , cast
6+ from typing import IO , TYPE_CHECKING , AnyStr , cast
77
88import maxminddb
99from maxminddb import (
1515 MODE_MMAP_EXT ,
1616)
1717
18+ from maxminddb import InvalidDatabaseError
19+
1820import geoip2
1921import geoip2 .errors
2022import geoip2 .models
21- from geoip2 .models import (
22- ASN ,
23- ISP ,
24- AnonymousIP ,
25- AnonymousPlus ,
26- City ,
27- ConnectionType ,
28- Country ,
29- Domain ,
30- Enterprise ,
31- )
32- from geoip2 .types import IPAddress
23+
24+ if TYPE_CHECKING :
25+ import os
26+ from collections .abc import Sequence
27+
28+ from typing_extensions import Self
29+
30+ from geoip2 .models import (
31+ ASN ,
32+ ISP ,
33+ AnonymousIP ,
34+ AnonymousPlus ,
35+ City ,
36+ ConnectionType ,
37+ Country ,
38+ Domain ,
39+ Enterprise ,
40+ )
41+ from geoip2 .types import IPAddress
3342
3443__all__ = [
3544 "MODE_AUTO" ,
@@ -67,8 +76,8 @@ class Reader:
6776
6877 def __init__ (
6978 self ,
70- fileish : Union [ AnyStr , int , os .PathLike , IO ] ,
71- locales : Optional [ Sequence [str ]] = None ,
79+ fileish : AnyStr | int | os .PathLike | IO ,
80+ locales : Sequence [str ] | None = None ,
7281 mode : int = MODE_AUTO ,
7382 ) -> None :
7483 """Create GeoIP2 Reader.
@@ -117,10 +126,10 @@ def __init__(
117126 self ._db_type = self ._db_reader .metadata ().database_type
118127 self ._locales = locales
119128
120- def __enter__ (self ) -> "Reader" :
129+ def __enter__ (self ) -> Self :
121130 return self
122131
123- def __exit__ (self , exc_type : None , exc_value : None , traceback : None ) -> None :
132+ def __exit__ (self , exc_type : object , exc_value : object , traceback : object ) -> None :
124133 self .close ()
125134
126135 def country (self , ip_address : IPAddress ) -> Country :
@@ -249,10 +258,12 @@ def isp(self, ip_address: IPAddress) -> ISP:
249258 self ._flat_model_for (geoip2 .models .ISP , "GeoIP2-ISP" , ip_address ),
250259 )
251260
252- def _get (self , database_type : str , ip_address : IPAddress ) -> Any :
261+ def _get (self , database_type : str , ip_address : IPAddress ) -> tuple [ dict , int ] :
253262 if database_type not in self ._db_type :
254263 caller = inspect .stack ()[2 ][3 ]
255- msg = f"The { caller } method cannot be used with the { self ._db_type } database"
264+ msg = (
265+ f"The { caller } method cannot be used with the { self ._db_type } database"
266+ )
256267 raise TypeError (
257268 msg ,
258269 )
@@ -264,14 +275,17 @@ def _get(self, database_type: str, ip_address: IPAddress) -> Any:
264275 str (ip_address ),
265276 prefix_len ,
266277 )
278+ if not isinstance (record , dict ):
279+ msg = f"Expected record to be a dict but was f{ type (record )} "
280+ raise InvalidDatabaseError (msg )
267281 return record , prefix_len
268282
269283 def _model_for (
270284 self ,
271- model_class : Union [ type [Country ], type [ Enterprise ], type [ City ] ],
285+ model_class : type [City | Country | Enterprise ],
272286 types : str ,
273287 ip_address : IPAddress ,
274- ) -> Union [ Country , Enterprise , City ] :
288+ ) -> City | Country | Enterprise :
275289 (record , prefix_len ) = self ._get (types , ip_address )
276290 return model_class (
277291 self ._locales ,
@@ -282,16 +296,10 @@ def _model_for(
282296
283297 def _flat_model_for (
284298 self ,
285- model_class : Union [
286- type [Domain ],
287- type [ISP ],
288- type [ConnectionType ],
289- type [ASN ],
290- type [AnonymousIP ],
291- ],
299+ model_class : type [Domain | ISP | ConnectionType | ASN | AnonymousIP ],
292300 types : str ,
293301 ip_address : IPAddress ,
294- ) -> Union [ ConnectionType , ISP , AnonymousIP , Domain , ASN ] :
302+ ) -> ConnectionType | ISP | AnonymousIP | Domain | ASN :
295303 (record , prefix_len ) = self ._get (types , ip_address )
296304 return model_class (ip_address , prefix_len = prefix_len , ** record )
297305
0 commit comments