Skip to content

Commit 2678689

Browse files
Bhargavasomucburgdorf
authored andcommitted
Enable complete type hinting for eth.chains
1 parent 017e206 commit 2678689

File tree

4 files changed

+103
-58
lines changed

4 files changed

+103
-58
lines changed

eth/chains/base.py

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
cast,
1313
Dict,
1414
Generator,
15+
Iterable,
1516
Iterator,
1617
List,
1718
Optional,
@@ -32,6 +33,12 @@
3233
encode_hex,
3334
)
3435

36+
from eth.constants import (
37+
BLANK_ROOT_HASH,
38+
EMPTY_UNCLE_HASH,
39+
MAX_UNCLE_DEPTH,
40+
)
41+
3542
from eth.db.backends.base import BaseAtomicDB
3643
from eth.db.chain import (
3744
BaseChainDB,
@@ -40,11 +47,7 @@
4047
from eth.db.header import (
4148
HeaderDB,
4249
)
43-
from eth.constants import (
44-
BLANK_ROOT_HASH,
45-
EMPTY_UNCLE_HASH,
46-
MAX_UNCLE_DEPTH,
47-
)
50+
4851
from eth.estimators import (
4952
get_gas_estimator,
5053
)
@@ -53,15 +56,7 @@
5356
TransactionNotFound,
5457
VMNotFound,
5558
)
56-
from eth.utils.spoof import (
57-
SpoofTransaction,
58-
)
59-
from eth.validation import (
60-
validate_block_number,
61-
validate_uint256,
62-
validate_word,
63-
validate_vm_configuration,
64-
)
59+
6560
from eth.rlp.blocks import (
6661
BaseBlock,
6762
)
@@ -76,6 +71,14 @@
7671
BaseTransaction,
7772
BaseUnsignedTransaction,
7873
)
74+
75+
from eth.typing import (
76+
AccountState,
77+
)
78+
79+
from eth.utils.spoof import (
80+
SpoofTransaction,
81+
)
7982
from eth.utils.db import (
8083
apply_state_dict,
8184
)
@@ -88,9 +91,15 @@
8891
from eth.utils.rlp import (
8992
validate_imported_block_unchanged,
9093
)
91-
from eth.typing import (
92-
AccountState,
94+
95+
from eth.validation import (
96+
validate_block_number,
97+
validate_uint256,
98+
validate_word,
99+
validate_vm_configuration,
93100
)
101+
from eth.vm.computation import BaseComputation
102+
from eth.vm.state import BaseState # noqa: F401
94103

95104
from eth._warnings import catch_and_ignore_import_warning
96105
with catch_and_ignore_import_warning():
@@ -107,7 +116,10 @@
107116
)
108117

109118
if TYPE_CHECKING:
110-
from eth.vm.base import BaseVM # noqa: F401
119+
from eth.vm.base import ( # noqa: F401
120+
BaseVM,
121+
VM,
122+
)
111123

112124

113125
class BaseChain(Configurable, ABC):
@@ -164,7 +176,7 @@ def get_vm_class(cls, header: BlockHeader) -> Type['BaseVM']:
164176
return cls.get_vm_class_for_block_number(header.block_number)
165177

166178
@abstractmethod
167-
def get_vm(self, header: BlockHeader=None) -> 'BaseVM':
179+
def get_vm(self, header: BlockHeader=None) -> 'VM':
168180
raise NotImplementedError("Chain classes must implement this method")
169181

170182
@classmethod
@@ -196,11 +208,11 @@ def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader:
196208
raise NotImplementedError("Chain classes must implement this method")
197209

198210
@abstractmethod
199-
def get_canonical_head(self):
211+
def get_canonical_head(self) -> BlockHeader:
200212
raise NotImplementedError("Chain classes must implement this method")
201213

202214
@abstractmethod
203-
def get_score(self, block_hash):
215+
def get_score(self, block_hash: Hash32) -> int:
204216
raise NotImplementedError("Chain classes must implement this method")
205217

206218
#
@@ -227,11 +239,14 @@ def get_canonical_block_by_number(self, block_number: BlockNumber) -> BaseBlock:
227239
raise NotImplementedError("Chain classes must implement this method")
228240

229241
@abstractmethod
230-
def get_canonical_block_hash(self, block_number):
242+
def get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32:
231243
raise NotImplementedError("Chain classes must implement this method")
232244

233245
@abstractmethod
234-
def build_block_with_transactions(self, transactions, parent_header):
246+
def build_block_with_transactions(self,
247+
transactions: Tuple[BaseTransaction, ...],
248+
parent_header: BlockHeader=None
249+
) -> Tuple[BaseBlock, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: # noqa: E501
235250
raise NotImplementedError("Chain classes must implement this method")
236251

237252
#
@@ -320,7 +335,7 @@ class Chain(BaseChain):
320335
current block number.
321336
"""
322337
logger = logging.getLogger("eth.chain.chain.Chain")
323-
gas_estimator = None # type: Callable
338+
gas_estimator = None # type: Callable[[BaseState, BaseTransaction], int]
324339

325340
chaindb_class = ChainDB # type: Type[BaseChainDB]
326341

@@ -334,8 +349,8 @@ def __init__(self, base_db: BaseAtomicDB) -> None:
334349

335350
self.chaindb = self.get_chaindb_class()(base_db)
336351
self.headerdb = HeaderDB(base_db)
337-
if self.gas_estimator is None:
338-
self.gas_estimator = get_gas_estimator() # type: ignore
352+
if self.gas_estimator is None: # type: ignore
353+
self.gas_estimator = get_gas_estimator() # type: ignore
339354

340355
#
341356
# Helpers
@@ -403,7 +418,7 @@ def from_genesis_header(cls,
403418
#
404419
# VM API
405420
#
406-
def get_vm(self, at_header: BlockHeader=None) -> 'BaseVM':
421+
def get_vm(self, at_header: BlockHeader=None) -> 'VM':
407422
"""
408423
Returns the VM instance for the given block number.
409424
"""
@@ -414,7 +429,9 @@ def get_vm(self, at_header: BlockHeader=None) -> 'BaseVM':
414429
#
415430
# Header API
416431
#
417-
def create_header_from_parent(self, parent_header, **header_params):
432+
def create_header_from_parent(self,
433+
parent_header: BlockHeader,
434+
**header_params: HeaderParams) -> BlockHeader:
418435
"""
419436
Passthrough helper to the VM class of the block descending from the
420437
given header.
@@ -432,15 +449,15 @@ def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader:
432449
validate_word(block_hash, title="Block Hash")
433450
return self.chaindb.get_block_header_by_hash(block_hash)
434451

435-
def get_canonical_head(self):
452+
def get_canonical_head(self) -> BlockHeader:
436453
"""
437454
Returns the block header at the canonical chain head.
438455
439456
Raises CanonicalHeadNotFound if there's no head defined for the canonical chain.
440457
"""
441458
return self.chaindb.get_canonical_head()
442459

443-
def get_score(self, block_hash):
460+
def get_score(self, block_hash: Hash32) -> int:
444461
"""
445462
Returns the difficulty score of the block with the given hash.
446463
@@ -498,7 +515,7 @@ def get_block_by_hash(self, block_hash: Hash32) -> BaseBlock:
498515
block_header = self.get_block_header_by_hash(block_hash)
499516
return self.get_block_by_header(block_header)
500517

501-
def get_block_by_header(self, block_header):
518+
def get_block_by_header(self, block_header: BlockHeader) -> BaseBlock:
502519
"""
503520
Returns the requested block as specified by the block header.
504521
"""
@@ -524,7 +541,10 @@ def get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32:
524541
"""
525542
return self.chaindb.get_canonical_block_hash(block_number)
526543

527-
def build_block_with_transactions(self, transactions, parent_header=None):
544+
def build_block_with_transactions(self,
545+
transactions: Tuple[BaseTransaction, ...],
546+
parent_header: BlockHeader=None
547+
) -> Tuple[BaseBlock, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: # noqa: E501
528548
"""
529549
Generate a block with the provided transactions. This does *not* import
530550
that block into your chain. If you want this new block in your chain,
@@ -554,12 +574,12 @@ def get_canonical_transaction(self, transaction_hash: Hash32) -> BaseTransaction
554574
found in the main chain.
555575
"""
556576
(block_num, index) = self.chaindb.get_transaction_index(transaction_hash)
557-
VM = self.get_vm_class_for_block_number(block_num)
577+
VM_class = self.get_vm_class_for_block_number(block_num)
558578

559579
transaction = self.chaindb.get_transaction_by_index(
560580
block_num,
561581
index,
562-
VM.get_transaction_class(),
582+
VM_class.get_transaction_class(),
563583
)
564584

565585
if transaction.hash == transaction_hash:
@@ -627,7 +647,7 @@ def estimate_gas(
627647
if at_header is None:
628648
at_header = self.get_canonical_head()
629649
with self.get_vm(at_header).state_in_temp_block() as state:
630-
return self.gas_estimator(state, transaction)
650+
return self.gas_estimator(state, transaction) # type: ignore
631651

632652
def import_block(self,
633653
block: BaseBlock,
@@ -689,8 +709,8 @@ def import_block(self,
689709
# Validation API
690710
#
691711
def validate_receipt(self, receipt: Receipt, at_header: BlockHeader) -> None:
692-
VM = self.get_vm_class(at_header)
693-
VM.validate_receipt(receipt)
712+
VM_class = self.get_vm_class(at_header)
713+
VM_class.validate_receipt(receipt)
694714

695715
def validate_block(self, block: BaseBlock) -> None:
696716
"""
@@ -704,18 +724,18 @@ def validate_block(self, block: BaseBlock) -> None:
704724
"""
705725
if block.is_genesis:
706726
raise ValidationError("Cannot validate genesis block this way")
707-
VM = self.get_vm_class_for_block_number(BlockNumber(block.number))
727+
VM_class = self.get_vm_class_for_block_number(BlockNumber(block.number))
708728
parent_block = self.get_block_by_hash(block.header.parent_hash)
709-
VM.validate_header(block.header, parent_block.header, check_seal=True)
729+
VM_class.validate_header(block.header, parent_block.header, check_seal=True)
710730
self.validate_uncles(block)
711731
self.validate_gaslimit(block.header)
712732

713733
def validate_seal(self, header: BlockHeader) -> None:
714734
"""
715735
Validate the seal on the given header.
716736
"""
717-
VM = self.get_vm_class_for_block_number(BlockNumber(header.block_number))
718-
VM.validate_seal(header)
737+
VM_class = self.get_vm_class_for_block_number(BlockNumber(header.block_number))
738+
VM_class.validate_seal(header)
719739

720740
def validate_gaslimit(self, header: BlockHeader) -> None:
721741
"""
@@ -830,7 +850,7 @@ def validate_chain(
830850

831851

832852
@to_set
833-
def _extract_uncle_hashes(blocks):
853+
def _extract_uncle_hashes(blocks: Iterable[BaseBlock]) -> Iterable[Hash32]:
834854
for block in blocks:
835855
for uncle in block.uncles:
836856
yield uncle.hash
@@ -843,7 +863,9 @@ def __init__(self, base_db: BaseAtomicDB, header: BlockHeader=None) -> None:
843863
super().__init__(base_db)
844864
self.header = self.ensure_header(header)
845865

846-
def apply_transaction(self, transaction):
866+
def apply_transaction(self,
867+
transaction: BaseTransaction
868+
) -> Tuple[BaseBlock, Receipt, BaseComputation]:
847869
"""
848870
Applies the transaction to the current tip block.
849871
@@ -890,7 +912,7 @@ def mine_block(self, *args: Any, **kwargs: Any) -> BaseBlock:
890912
self.header = self.create_header_from_parent(mined_block.header)
891913
return mined_block
892914

893-
def get_vm(self, at_header: BlockHeader=None) -> 'BaseVM':
915+
def get_vm(self, at_header: BlockHeader=None) -> 'VM':
894916
if at_header is None:
895917
at_header = self.header
896918

eth/chains/header.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from abc import ABC, abstractmethod
2-
from typing import Dict, Any, Tuple, Type # noqa: F401
2+
from typing import ( # noqa: F401
3+
Any,
4+
cast,
5+
Dict,
6+
Tuple,
7+
Type,
8+
)
39

410
from eth_typing import (
511
BlockNumber,
612
Hash32,
713
)
814

9-
from eth.db.backends.base import BaseDB
15+
from eth.db.backends.base import (
16+
BaseAtomicDB,
17+
BaseDB,
18+
)
1019
from eth.db.header import ( # noqa: F401
1120
BaseHeaderDB,
1221
HeaderDB,
@@ -47,7 +56,7 @@ def from_genesis_header(cls,
4756
#
4857
@classmethod
4958
@abstractmethod
50-
def get_headerdb_class(cls):
59+
def get_headerdb_class(cls) -> Type[BaseHeaderDB]:
5160
raise NotImplementedError("Chain classes must implement this method")
5261

5362
#
@@ -73,7 +82,9 @@ def header_exists(self, block_hash: Hash32) -> bool:
7382
raise NotImplementedError("Chain classes must implement this method")
7483

7584
@abstractmethod
76-
def import_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
85+
def import_header(self,
86+
header: BlockHeader
87+
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
7788
raise NotImplementedError("Chain classes must implement this method")
7889

7990

@@ -82,7 +93,7 @@ class HeaderChain(BaseHeaderChain):
8293

8394
def __init__(self, base_db: BaseDB, header: BlockHeader=None) -> None:
8495
self.base_db = base_db
85-
self.headerdb = self.get_headerdb_class()(base_db)
96+
self.headerdb = self.get_headerdb_class()(cast(BaseAtomicDB, base_db))
8697

8798
if header is None:
8899
self.header = self.get_canonical_head()
@@ -99,15 +110,15 @@ def from_genesis_header(cls,
99110
"""
100111
Initializes the chain from the genesis header.
101112
"""
102-
headerdb = cls.get_headerdb_class()(base_db)
113+
headerdb = cls.get_headerdb_class()(cast(BaseAtomicDB, base_db))
103114
headerdb.persist_header(genesis_header)
104115
return cls(base_db, genesis_header)
105116

106117
#
107118
# Helpers
108119
#
109120
@classmethod
110-
def get_headerdb_class(cls):
121+
def get_headerdb_class(cls) -> Type[BaseHeaderDB]:
111122
"""
112123
Returns the class which should be used for the `headerdb`
113124
"""
@@ -151,7 +162,9 @@ def header_exists(self, block_hash: Hash32) -> bool:
151162
"""
152163
return self.headerdb.header_exists(block_hash)
153164

154-
def import_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
165+
def import_header(self,
166+
header: BlockHeader
167+
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
155168
"""
156169
Direct passthrough to `headerdb`
157170

0 commit comments

Comments
 (0)