Skip to content

Commit 72a99db

Browse files
Bhargavasomucburgdorf
authored andcommitted
Enable complete type hinting for eth.db
1 parent 64ac6a4 commit 72a99db

File tree

10 files changed

+124
-86
lines changed

10 files changed

+124
-86
lines changed

eth/beacon/db/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _get_canonical_head(cls, db: BaseDB) -> BaseBeaconBlock:
243243
canonical_head_hash = db[SchemaV1.make_canonical_head_hash_lookup_key()]
244244
except KeyError:
245245
raise CanonicalHeadNotFound("No canonical head set for this chain")
246-
return cls._get_block_by_hash(db, canonical_head_hash)
246+
return cls._get_block_by_hash(db, Hash32(canonical_head_hash))
247247

248248
def get_block_by_hash(self, block_hash: Hash32) -> BaseBeaconBlock:
249249
return self._get_block_by_hash(self.db, block_hash)

eth/db/account.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
BLANK_ROOT_HASH,
2929
EMPTY_SHA3,
3030
)
31+
from eth.db.backends.base import (
32+
BaseDB,
33+
)
3134
from eth.db.batch import (
3235
BatchDB,
3336
)
@@ -71,7 +74,7 @@ def __init__(self) -> None:
7174

7275
@property
7376
@abstractmethod
74-
def state_root(self):
77+
def state_root(self) -> Hash32:
7578
raise NotImplementedError("Must be implemented by subclasses")
7679

7780
@abstractmethod
@@ -82,11 +85,11 @@ def has_root(self, state_root: bytes) -> bool:
8285
# Storage
8386
#
8487
@abstractmethod
85-
def get_storage(self, address, slot):
88+
def get_storage(self, address: Address, slot: int) -> int:
8689
raise NotImplementedError("Must be implemented by subclasses")
8790

8891
@abstractmethod
89-
def set_storage(self, address, slot, value):
92+
def set_storage(self, address: Address, slot: int, value: int) -> None:
9093
raise NotImplementedError("Must be implemented by subclasses")
9194

9295
#
@@ -104,40 +107,40 @@ def set_nonce(self, address: Address, nonce: int) -> None:
104107
# Balance
105108
#
106109
@abstractmethod
107-
def get_balance(self, address):
110+
def get_balance(self, address: Address) -> int:
108111
raise NotImplementedError("Must be implemented by subclasses")
109112

110113
@abstractmethod
111-
def set_balance(self, address, balance):
114+
def set_balance(self, address: Address, balance: int) -> None:
112115
raise NotImplementedError("Must be implemented by subclasses")
113116

114-
def delta_balance(self, address, delta):
117+
def delta_balance(self, address: Address, delta: int) -> None:
115118
self.set_balance(address, self.get_balance(address) + delta)
116119

117120
#
118121
# Code
119122
#
120123
@abstractmethod
121-
def set_code(self, address, code):
124+
def set_code(self, address: Address, code: bytes) -> None:
122125
raise NotImplementedError("Must be implemented by subclasses")
123126

124127
@abstractmethod
125-
def get_code(self, address):
128+
def get_code(self, address: Address) -> bytes:
126129
raise NotImplementedError("Must be implemented by subclasses")
127130

128131
@abstractmethod
129-
def get_code_hash(self, address):
132+
def get_code_hash(self, address: Address) -> Hash32:
130133
raise NotImplementedError("Must be implemented by subclasses")
131134

132135
@abstractmethod
133-
def delete_code(self, address):
136+
def delete_code(self, address: Address) -> None:
134137
raise NotImplementedError("Must be implemented by subclasses")
135138

136139
#
137140
# Account Methods
138141
#
139142
@abstractmethod
140-
def account_is_empty(self, address):
143+
def account_is_empty(self, address: Address) -> bool:
141144
raise NotImplementedError("Must be implemented by subclass")
142145

143146
#
@@ -177,7 +180,7 @@ class AccountDB(BaseAccountDB):
177180

178181
logger = cast(TraceLogger, logging.getLogger('eth.db.account.AccountDB'))
179182

180-
def __init__(self, db, state_root=BLANK_ROOT_HASH):
183+
def __init__(self, db: BaseDB, state_root: Hash32=BLANK_ROOT_HASH) -> None:
181184
r"""
182185
Internal implementation details (subject to rapid change):
183186
Database entries go through several pipes, like so...
@@ -225,11 +228,11 @@ def __init__(self, db, state_root=BLANK_ROOT_HASH):
225228
self._journaltrie = JournalDB(self._trie_cache)
226229

227230
@property
228-
def state_root(self):
231+
def state_root(self) -> Hash32:
229232
return self._trie.root_hash
230233

231234
@state_root.setter
232-
def state_root(self, value):
235+
def state_root(self, value: Hash32) -> None:
233236
self._trie_cache.reset_cache()
234237
self._trie.root_hash = value
235238

@@ -239,7 +242,7 @@ def has_root(self, state_root: bytes) -> bool:
239242
#
240243
# Storage
241244
#
242-
def get_storage(self, address, slot, from_journal=True):
245+
def get_storage(self, address: Address, slot: int, from_journal: bool=True) -> int:
243246
validate_canonical_address(address, title="Storage Address")
244247
validate_uint256(slot, title="Storage Slot")
245248

@@ -254,7 +257,7 @@ def get_storage(self, address, slot, from_journal=True):
254257
else:
255258
return 0
256259

257-
def set_storage(self, address, slot, value):
260+
def set_storage(self, address: Address, slot: int, value: int) -> None:
258261
validate_uint256(value, title="Storage Value")
259262
validate_uint256(slot, title="Storage Slot")
260263
validate_canonical_address(address, title="Storage Address")
@@ -272,7 +275,7 @@ def set_storage(self, address, slot, value):
272275

273276
self._set_account(address, account.copy(storage_root=storage.root_hash))
274277

275-
def delete_storage(self, address):
278+
def delete_storage(self, address: Address) -> None:
276279
validate_canonical_address(address, title="Storage Address")
277280

278281
account = self._get_account(address)
@@ -281,13 +284,13 @@ def delete_storage(self, address):
281284
#
282285
# Balance
283286
#
284-
def get_balance(self, address):
287+
def get_balance(self, address: Address) -> int:
285288
validate_canonical_address(address, title="Storage Address")
286289

287290
account = self._get_account(address)
288291
return account.balance
289292

290-
def set_balance(self, address, balance):
293+
def set_balance(self, address: Address, balance: int) -> None:
291294
validate_canonical_address(address, title="Storage Address")
292295
validate_uint256(balance, title="Account Balance")
293296

@@ -297,35 +300,35 @@ def set_balance(self, address, balance):
297300
#
298301
# Nonce
299302
#
300-
def get_nonce(self, address):
303+
def get_nonce(self, address: Address) -> int:
301304
validate_canonical_address(address, title="Storage Address")
302305

303306
account = self._get_account(address)
304307
return account.nonce
305308

306-
def set_nonce(self, address, nonce):
309+
def set_nonce(self, address: Address, nonce: int) -> None:
307310
validate_canonical_address(address, title="Storage Address")
308311
validate_uint256(nonce, title="Nonce")
309312

310313
account = self._get_account(address)
311314
self._set_account(address, account.copy(nonce=nonce))
312315

313-
def increment_nonce(self, address):
316+
def increment_nonce(self, address: Address) -> None:
314317
current_nonce = self.get_nonce(address)
315318
self.set_nonce(address, current_nonce + 1)
316319

317320
#
318321
# Code
319322
#
320-
def get_code(self, address):
323+
def get_code(self, address: Address) -> bytes:
321324
validate_canonical_address(address, title="Storage Address")
322325

323326
try:
324327
return self._journaldb[self.get_code_hash(address)]
325328
except KeyError:
326329
return b""
327330

328-
def set_code(self, address, code):
331+
def set_code(self, address: Address, code: bytes) -> None:
329332
validate_canonical_address(address, title="Storage Address")
330333
validate_is_bytes(code, title="Code")
331334

@@ -335,13 +338,13 @@ def set_code(self, address, code):
335338
self._journaldb[code_hash] = code
336339
self._set_account(address, account.copy(code_hash=code_hash))
337340

338-
def get_code_hash(self, address):
341+
def get_code_hash(self, address: Address) -> Hash32:
339342
validate_canonical_address(address, title="Storage Address")
340343

341344
account = self._get_account(address)
342345
return account.code_hash
343346

344-
def delete_code(self, address):
347+
def delete_code(self, address: Address) -> None:
345348
validate_canonical_address(address, title="Storage Address")
346349

347350
account = self._get_account(address)
@@ -350,40 +353,40 @@ def delete_code(self, address):
350353
#
351354
# Account Methods
352355
#
353-
def account_has_code_or_nonce(self, address):
356+
def account_has_code_or_nonce(self, address: Address) -> bool:
354357
return self.get_nonce(address) != 0 or self.get_code_hash(address) != EMPTY_SHA3
355358

356-
def delete_account(self, address):
359+
def delete_account(self, address: Address) -> None:
357360
validate_canonical_address(address, title="Storage Address")
358361

359362
del self._journaltrie[address]
360363

361-
def account_exists(self, address):
364+
def account_exists(self, address: Address) -> bool:
362365
validate_canonical_address(address, title="Storage Address")
363366

364367
return self._journaltrie.get(address, b'') != b''
365368

366-
def touch_account(self, address):
369+
def touch_account(self, address: Address) -> None:
367370
validate_canonical_address(address, title="Storage Address")
368371

369372
account = self._get_account(address)
370373
self._set_account(address, account)
371374

372-
def account_is_empty(self, address):
375+
def account_is_empty(self, address: Address) -> bool:
373376
return not self.account_has_code_or_nonce(address) and self.get_balance(address) == 0
374377

375378
#
376379
# Internal
377380
#
378-
def _get_account(self, address, from_journal=True):
381+
def _get_account(self, address: Address, from_journal: bool=True) -> Account:
379382
rlp_account = (self._journaltrie if from_journal else self._trie_cache).get(address, b'')
380383
if rlp_account:
381384
account = rlp.decode(rlp_account, sedes=Account)
382385
else:
383386
account = Account()
384387
return account
385388

386-
def _set_account(self, address, account):
389+
def _set_account(self, address: Address, account: Account) -> None:
387390
rlp_account = rlp.encode(account, sedes=Account)
388391
self._journaltrie[address] = rlp_account
389392

@@ -424,7 +427,7 @@ def _log_pending_accounts(self) -> None:
424427
continue
425428
else:
426429
accounts_displayed.add(address)
427-
account = self._get_account(address)
430+
account = self._get_account(Address(address))
428431
self.logger.trace(
429432
"Account %s: balance %d, nonce %d, storage root %s, code hash %s",
430433
encode_hex(address),

eth/db/atomic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from contextlib import contextmanager
22
import logging
3-
from typing import Generator
3+
from typing import (
4+
Generator,
5+
Iterator,
6+
)
47

58
from eth_utils import (
69
ValidationError,
@@ -110,14 +113,14 @@ def _exists(self, key: bytes) -> bool:
110113

111114
@classmethod
112115
@contextmanager
113-
def _commit_unless_raises(cls, write_target_db):
116+
def _commit_unless_raises(cls, write_target_db: BaseDB) -> Iterator['AtomicDBWriteBatch']:
114117
"""
115118
Commit all writes inside the context, unless an exception was raised.
116119
117120
Although this is technically an external API, it (and this whole class) is only intended
118121
to be used by AtomicDB.
119122
"""
120-
readable_write_batch = cls(write_target_db)
123+
readable_write_batch = cls(write_target_db) # type: AtomicDBWriteBatch
121124
try:
122125
yield readable_write_batch
123126
except Exception:

eth/db/backends/base.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66
MutableMapping,
77
)
88

9+
from typing import (
10+
Any,
11+
TYPE_CHECKING
12+
)
13+
14+
if TYPE_CHECKING:
15+
MM = MutableMapping[bytes, bytes]
16+
else:
17+
MM = MutableMapping
18+
919

10-
class BaseDB(MutableMapping, ABC):
20+
class BaseDB(MM, ABC):
1121
"""
1222
This is an abstract key/value lookup with all :class:`bytes` values,
1323
with some convenience methods for databases. As much as possible,
@@ -35,9 +45,10 @@ def set(self, key: bytes, value: bytes) -> None:
3545
def exists(self, key: bytes) -> bool:
3646
return self.__contains__(key)
3747

38-
def __contains__(self, key):
48+
def __contains__(self, key: bytes) -> bool: # type: ignore # Breaks LSP
3949
if hasattr(self, '_exists'):
40-
return self._exists(key)
50+
# Classes which inherit this class would have `_exists` attr
51+
return self._exists(key) # type: ignore
4152
else:
4253
return super().__contains__(key)
4354

@@ -47,10 +58,10 @@ def delete(self, key: bytes) -> None:
4758
except KeyError:
4859
return None
4960

50-
def __iter__(self):
51-
raise NotImplementedError("By default, DB classes cannot by iterated.")
61+
def __iter__(self) -> None:
62+
raise NotImplementedError("By default, DB classes cannot be iterated.")
5263

53-
def __len__(self):
64+
def __len__(self) -> int:
5465
raise NotImplementedError("By default, DB classes cannot return the total number of keys.")
5566

5667

@@ -80,5 +91,5 @@ class BaseAtomicDB(BaseDB):
8091
# or neither will
8192
"""
8293
@abstractmethod
83-
def atomic_batch(self):
94+
def atomic_batch(self) -> Any:
8495
raise NotImplementedError

eth/db/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
3737
self.clear()
3838
self.logger.exception("Unexpected error occurred during batch update")
3939

40-
def clear(self):
40+
def clear(self) -> None:
4141
self._track_diff = DBDiffTracker()
4242

4343
def commit(self, apply_deletes: bool = True) -> None:

eth/db/cache.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from lru import LRU
22

3+
from typing import Any
4+
35
from eth.db.backends.base import BaseDB
46

57

@@ -8,24 +10,24 @@ class CacheDB(BaseDB):
810
Set and get decoded RLP objects, where the underlying db stores
911
encoded objects.
1012
"""
11-
def __init__(self, db, cache_size=2048):
13+
def __init__(self, db: BaseDB, cache_size: int=2048) -> None:
1214
self._db = db
1315
self._cache_size = cache_size
1416
self.reset_cache()
1517

16-
def reset_cache(self):
18+
def reset_cache(self) -> None:
1719
self._cached_values = LRU(self._cache_size)
1820

19-
def __getitem__(self, key):
21+
def __getitem__(self, key: Any) -> Any:
2022
if key not in self._cached_values:
2123
self._cached_values[key] = self._db[key]
2224
return self._cached_values[key]
2325

24-
def __setitem__(self, key, value):
26+
def __setitem__(self, key: Any, value: Any) -> None:
2527
self._cached_values[key] = value
2628
self._db[key] = value
2729

28-
def __delitem__(self, key):
30+
def __delitem__(self, key: Any) -> None:
2931
if key in self._cached_values:
3032
del self._cached_values[key]
3133
del self._db[key]

0 commit comments

Comments
 (0)