Skip to content

Commit 312386e

Browse files
committed
Improve typing
1 parent 58e0021 commit 312386e

File tree

5 files changed

+82
-56
lines changed

5 files changed

+82
-56
lines changed

maxminddb/decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Decoder for the MaxMind DB data section."""
22

33
import struct
4-
from typing import Union, cast
4+
from typing import ClassVar, Union, cast
55

66
try:
77
# pylint: disable=unused-import
88
import mmap
99
except ImportError:
1010
# pylint: disable=invalid-name
11-
mmap = None # type: ignore
11+
mmap = None # type: ignore[assignment]
1212

1313

1414
from maxminddb.errors import InvalidDatabaseError
@@ -116,7 +116,7 @@ def _decode_utf8_string(self, size: int, offset: int) -> tuple[str, int]:
116116
new_offset = offset + size
117117
return self._buffer[offset:new_offset].decode("utf-8"), new_offset
118118

119-
_type_decoder = {
119+
_type_decoder: ClassVar = {
120120
1: _decode_pointer,
121121
2: _decode_utf8_string,
122122
3: _decode_double,

maxminddb/file.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""For internal use only. It provides a slice-like file reader."""
22

3+
from __future__ import annotations
4+
35
import os
4-
from typing import Union
6+
from typing import overload
57

68
try:
79
# pylint: disable=no-name-in-module
810
from multiprocessing import Lock
911
except ImportError:
10-
from threading import Lock # type: ignore
12+
from threading import Lock # type: ignore[assignment]
1113

1214

1315
class FileBuffer:
@@ -20,11 +22,18 @@ def __init__(self, database: str) -> None:
2022
if not hasattr(os, "pread"):
2123
self._lock = Lock()
2224

23-
def __getitem__(self, key: Union[slice, int]):
24-
if isinstance(key, slice):
25-
return self._read(key.stop - key.start, key.start)
26-
if isinstance(key, int):
27-
return self._read(1, key)[0]
25+
@overload
26+
def __getitem__(self, index: int) -> int: ...
27+
28+
@overload
29+
def __getitem__(self, index: slice) -> bytes: ...
30+
31+
def __getitem__(self, index: slice | int) -> bytes | int:
32+
"""Get item by index."""
33+
if isinstance(index, slice):
34+
return self._read(index.stop - index.start, index.start)
35+
if isinstance(index, int):
36+
return self._read(1, index)[0]
2837
msg = "Invalid argument type."
2938
raise TypeError(msg)
3039

@@ -43,12 +52,12 @@ def close(self) -> None:
4352
"""Close file."""
4453
self._handle.close()
4554

46-
if hasattr(os, "pread"):
55+
if hasattr(os, "pread"): # type: ignore[attr-defined]
4756

4857
def _read(self, buffersize: int, offset: int) -> bytes:
4958
"""Read that uses pread."""
5059
# pylint: disable=no-member
51-
return os.pread(self._handle.fileno(), buffersize, offset) # type: ignore
60+
return os.pread(self._handle.fileno(), buffersize, offset) # type: ignore[]
5261

5362
else:
5463

maxminddb/reader.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import mmap
55
except ImportError:
66
# pylint: disable=invalid-name
7-
mmap = None # type: ignore
7+
mmap = None # type: ignore[assignment]
88

99
import ipaddress
1010
import struct
@@ -58,24 +58,24 @@ def __init__(
5858
"""
5959
filename: Any
6060
if (mode == MODE_AUTO and mmap) or mode == MODE_MMAP:
61-
with open(database, "rb") as db_file: # type: ignore
61+
with open(database, "rb") as db_file: # type: ignore[arg-type]
6262
self._buffer = mmap.mmap(db_file.fileno(), 0, access=mmap.ACCESS_READ)
6363
self._buffer_size = self._buffer.size()
6464
filename = database
6565
elif mode in (MODE_AUTO, MODE_FILE):
66-
self._buffer = FileBuffer(database) # type: ignore
66+
self._buffer = FileBuffer(database) # type: ignore[arg-type]
6767
self._buffer_size = self._buffer.size()
6868
filename = database
6969
elif mode == MODE_MEMORY:
70-
with open(database, "rb") as db_file: # type: ignore
70+
with open(database, "rb") as db_file: # type: ignore[arg-type]
7171
buf = db_file.read()
7272
self._buffer = buf
7373
self._buffer_size = len(buf)
7474
filename = database
7575
elif mode == MODE_FD:
76-
self._buffer = database.read() # type: ignore
77-
self._buffer_size = len(self._buffer) # type: ignore
78-
filename = database.name # type: ignore
76+
self._buffer = database.read() # type: ignore[union-attr]
77+
self._buffer_size = len(self._buffer) # type: ignore[arg-type]
78+
filename = database.name # type: ignore[union-attr]
7979
else:
8080
msg = (
8181
f"Unsupported open mode ({mode}). Only MODE_AUTO, MODE_FILE, "
@@ -185,7 +185,7 @@ def get_with_prefix_len(
185185
def __iter__(self) -> Iterator:
186186
return self._generate_children(0, 0, 0)
187187

188-
def _generate_children(self, node, depth, ip_acc) -> Iterator:
188+
def _generate_children(self, node: int, depth: int, ip_acc: int) -> Iterator:
189189
if ip_acc != 0 and node == self._ipv4_start:
190190
# Skip nodes aliased to IPv4
191191
return

tests/decoder_test.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
#!/usr/bin/env python
1+
from __future__ import annotations
22

33
import mmap
44
import unittest
5+
from typing import TYPE_CHECKING, Any, ClassVar
56

67
from maxminddb.decoder import Decoder
78

9+
if TYPE_CHECKING:
10+
from _typeshed import SizedBuffer
11+
812

913
class TestDecoder(unittest.TestCase):
1014
def test_arrays(self) -> None:
@@ -101,7 +105,7 @@ def test_pointer(self) -> None:
101105
}
102106
self.validate_type_decoding("pointers", pointers)
103107

104-
strings = {
108+
strings: ClassVar = {
105109
b"\x40": "",
106110
b"\x41\x31": "1",
107111
b"\x43\xe4\xba\xba": "人",
@@ -165,7 +169,7 @@ def test_uint32(self) -> None:
165169
}
166170
self.validate_type_decoding("uint32", uint32)
167171

168-
def generate_large_uint(self, bits) -> dict:
172+
def generate_large_uint(self, bits: int) -> dict:
169173
ctrl_byte = b"\x02" if bits == 64 else b"\x03"
170174
uints = {
171175
b"\x00" + ctrl_byte: 0,
@@ -174,8 +178,8 @@ def generate_large_uint(self, bits) -> dict:
174178
}
175179
for power in range(bits // 8 + 1):
176180
expected = 2 ** (8 * power) - 1
177-
input = bytes([power]) + ctrl_byte + (b"\xff" * power)
178-
uints[input] = expected
181+
input_value = bytes([power]) + ctrl_byte + (b"\xff" * power)
182+
uints[input_value] = expected
179183
return uints
180184

181185
def test_uint64(self) -> None:
@@ -184,25 +188,31 @@ def test_uint64(self) -> None:
184188
def test_uint128(self) -> None:
185189
self.validate_type_decoding("uint128", self.generate_large_uint(128))
186190

187-
def validate_type_decoding(self, type, tests) -> None:
188-
for input, expected in tests.items():
189-
self.check_decoding(type, input, expected)
190-
191-
def check_decoding(self, type, input, expected, name=None) -> None:
191+
def validate_type_decoding(self, data_type: str, tests: dict) -> None:
192+
for input_value, expected in tests.items():
193+
self.check_decoding(data_type, input_value, expected)
194+
195+
def check_decoding(
196+
self,
197+
data_type: str,
198+
input_value: SizedBuffer,
199+
expected: Any, # noqa: ANN401
200+
name: str | None = None,
201+
) -> None:
192202
name = name or expected
193-
db = mmap.mmap(-1, len(input))
194-
db.write(input)
203+
db = mmap.mmap(-1, len(input_value))
204+
db.write(input_value)
195205

196206
decoder = Decoder(db, pointer_test=True)
197207
(
198208
actual,
199209
_,
200210
) = decoder.decode(0)
201211

202-
if type in ("float", "double"):
203-
self.assertAlmostEqual(expected, actual, places=3, msg=type)
212+
if data_type in ("float", "double"):
213+
self.assertAlmostEqual(expected, actual, places=3, msg=data_type)
204214
else:
205-
self.assertEqual(expected, actual, type)
215+
self.assertEqual(expected, actual, data_type)
206216

207217
def test_real_pointers(self) -> None:
208218
with open("tests/data/test-data/maps-with-pointers.raw", "r+b") as db_file:

tests/reader_test.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
try:
1515
import maxminddb.extension
1616
except ImportError:
17-
maxminddb.extension = None # type: ignore
17+
maxminddb.extension = None # type: ignore[assignment]
1818

1919
from maxminddb import InvalidDatabaseError, open_database
2020
from maxminddb.const import (
@@ -28,7 +28,7 @@
2828
from maxminddb.reader import Reader
2929

3030

31-
def get_reader_from_file_descriptor(filepath, mode) -> Reader:
31+
def get_reader_from_file_descriptor(filepath: str, mode: int) -> Reader:
3232
"""Patches open_database() for class TestFDReader()."""
3333
if mode == MODE_FD:
3434
with open(filepath, "rb") as mmdb_fh:
@@ -53,7 +53,7 @@ class BaseTestReader(unittest.TestCase):
5353
if os.name != "nt":
5454
mp = multiprocessing.get_context("fork")
5555

56-
def ipf(self, ip) -> Union[ipaddress.IPv4Address, ipaddress.IPv6Address]:
56+
def ipf(self, ip: str) -> Union[ipaddress.IPv4Address, ipaddress.IPv6Address, str]:
5757
if self.use_ip_objects:
5858
return ipaddress.ip_address(ip)
5959
return ip
@@ -179,7 +179,9 @@ def test_get_with_prefix_len(self) -> None:
179179
"tests/data/test-data/" + cast("str", test["file_name"]),
180180
self.mode,
181181
) as reader:
182-
(record, prefix_len) = reader.get_with_prefix_len(cast("str", test["ip"]))
182+
(record, prefix_len) = reader.get_with_prefix_len(
183+
cast("str", test["ip"])
184+
)
183185

184186
self.assertEqual(
185187
prefix_len,
@@ -314,8 +316,8 @@ def test_opening_path(self) -> None:
314316
self.assertEqual(reader.metadata().database_type, "MaxMind DB Decoder Test")
315317

316318
def test_no_extension_exception(self) -> None:
317-
real_extension = maxminddb._extension
318-
maxminddb._extension = None # type: ignore
319+
real_extension = maxminddb._extension # noqa: SLF001
320+
maxminddb._extension = None # type: ignore[assignment] # noqa: SLF001
319321
with self.assertRaisesRegex(
320322
ValueError,
321323
"MODE_MMAP_EXT requires the maxminddb.extension module to be available",
@@ -375,11 +377,11 @@ def test_database_with_invalid_utf8_key(self) -> None:
375377

376378
def test_too_many_constructor_args(self) -> None:
377379
with self.assertRaises(TypeError):
378-
self.reader_class("README.md", self.mode, 1) # type: ignore
380+
self.reader_class("README.md", self.mode, 1) # type: ignore[arg-type,call-arg]
379381

380382
def test_bad_constructor_mode(self) -> None:
381383
with self.assertRaisesRegex(ValueError, r"Unsupported open mode \(100\)"):
382-
self.reader_class("README.md", mode=100) # type: ignore
384+
self.reader_class("README.md", mode=100) # type: ignore[arg-type]
383385

384386
def test_no_constructor_args(self) -> None:
385387
with self.assertRaisesRegex(
@@ -389,15 +391,15 @@ def test_no_constructor_args(self) -> None:
389391
r"takes at least 2 arguments|"
390392
r"function missing required argument \'database\' \(pos 1\)",
391393
):
392-
self.reader_class() # type: ignore
394+
self.reader_class() # type: ignore[call-arg]
393395

394396
def test_too_many_get_args(self) -> None:
395397
reader = open_database(
396398
"tests/data/test-data/MaxMind-DB-test-decoder.mmdb",
397399
self.mode,
398400
)
399401
with self.assertRaises(TypeError):
400-
reader.get(self.ipf("1.1.1.1"), "blah") # type: ignore
402+
reader.get(self.ipf("1.1.1.1"), "blah") # type: ignore[call-arg]
401403
reader.close()
402404

403405
def test_no_get_args(self) -> None:
@@ -406,7 +408,7 @@ def test_no_get_args(self) -> None:
406408
self.mode,
407409
)
408410
with self.assertRaises(TypeError):
409-
reader.get() # type: ignore
411+
reader.get() # type: ignore[call-arg]
410412
reader.close()
411413

412414
def test_incorrect_get_arg_type(self) -> None:
@@ -415,7 +417,7 @@ def test_incorrect_get_arg_type(self) -> None:
415417
TypeError,
416418
"argument 1 must be a string or ipaddress object",
417419
):
418-
reader.get(1) # type: ignore
420+
reader.get(1) # type: ignore[arg-type]
419421
reader.close()
420422

421423
def test_metadata_args(self) -> None:
@@ -424,7 +426,7 @@ def test_metadata_args(self) -> None:
424426
self.mode,
425427
)
426428
with self.assertRaises(TypeError):
427-
reader.metadata("blah") # type: ignore
429+
reader.metadata("blah") # type: ignore[call-arg]
428430
reader.close()
429431

430432
def test_metadata_unknown_attribute(self) -> None:
@@ -437,7 +439,7 @@ def test_metadata_unknown_attribute(self) -> None:
437439
AttributeError,
438440
"'Metadata' object has no attribute 'blah'",
439441
):
440-
metadata.blah # type: ignore
442+
metadata.blah # type: ignore[attr-defined] # noqa: B018
441443
reader.close()
442444

443445
def test_close(self) -> None:
@@ -547,7 +549,7 @@ def lookup(pipe) -> None:
547549
except:
548550
pipe.send(0)
549551
finally:
550-
if worker_class is self.mp.Process: # type: ignore
552+
if worker_class is self.mp.Process: # type: ignore[attr-defined]
551553
reader.close()
552554
pipe.close()
553555

@@ -560,11 +562,16 @@ def lookup(pipe) -> None:
560562

561563
reader.close()
562564

563-
count = sum([p.recv() for (p, c) in pipes])
565+
count = sum([p.recv() for (p, _) in pipes])
564566

565567
self.assertEqual(count, 32, "expected number of successful lookups")
566568

567-
def _check_metadata(self, reader, ip_version, record_size) -> None:
569+
def _check_metadata(
570+
self,
571+
reader: Reader,
572+
ip_version: int,
573+
record_size: int,
574+
) -> None:
568575
metadata = reader.metadata()
569576

570577
self.assertEqual(2, metadata.binary_format_major_version, "major version")
@@ -582,7 +589,7 @@ def _check_metadata(self, reader, ip_version, record_size) -> None:
582589

583590
self.assertEqual(metadata.record_size, record_size)
584591

585-
def _check_ip_v4(self, reader, file_name) -> None:
592+
def _check_ip_v4(self, reader: Reader, file_name: str) -> None:
586593
for i in range(6):
587594
address = "1.1.1." + str(pow(2, i))
588595
self.assertEqual(
@@ -612,7 +619,7 @@ def _check_ip_v4(self, reader, file_name) -> None:
612619
for ip in ["1.1.1.33", "255.254.253.123"]:
613620
self.assertIsNone(reader.get(self.ipf(ip)))
614621

615-
def _check_ip_v6(self, reader, file_name) -> None:
622+
def _check_ip_v6(self, reader: Reader, file_name: str) -> None:
616623
subnets = ["::1:ffff:ffff", "::2:0:0", "::2:0:40", "::2:0:50", "::2:0:58"]
617624

618625
for address in subnets:
@@ -645,10 +652,10 @@ def _check_ip_v6(self, reader, file_name) -> None:
645652

646653

647654
def has_maxminddb_extension() -> bool:
648-
return maxminddb.extension and hasattr(
655+
return maxminddb.extension is not None and hasattr(
649656
maxminddb.extension,
650657
"Reader",
651-
) # type: ignore
658+
)
652659

653660

654661
@unittest.skipIf(

0 commit comments

Comments
 (0)