Skip to content

Commit f12a207

Browse files
authored
Merge pull request #1708 from hwwhww/fix_beacon_chain_db
Pass `block_class` explictly in `BeaconChainDB` APIs
2 parents ef52da2 + 058bc83 commit f12a207

File tree

5 files changed

+55
-48
lines changed

5 files changed

+55
-48
lines changed

eth/beacon/db/chain.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,40 @@
5353

5454
class BaseBeaconChainDB(ABC):
5555
db = None # type: BaseAtomicDB
56-
block_class = None # type: Type[BaseBeaconBlock]
57-
58-
@abstractmethod
59-
def set_block_class(self, block_class: Type[BaseBeaconBlock]) -> None:
60-
pass
6156

6257
#
6358
# Block API
6459
#
6560
@abstractmethod
66-
def persist_block(self,
67-
block: BaseBeaconBlock) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
61+
def persist_block(
62+
self,
63+
block: BaseBeaconBlock,
64+
block_class: Type[BaseBeaconBlock]
65+
) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
6866
pass
6967

7068
@abstractmethod
7169
def get_canonical_block_root(self, slot: int) -> Hash32:
7270
pass
7371

7472
@abstractmethod
75-
def get_canonical_block_by_slot(self, slot: int) -> BaseBeaconBlock:
73+
def get_canonical_block_by_slot(self,
74+
slot: int,
75+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
7676
pass
7777

7878
@abstractmethod
7979
def get_canonical_block_root_by_slot(self, slot: int) -> Hash32:
8080
pass
8181

8282
@abstractmethod
83-
def get_canonical_head(self) -> BaseBeaconBlock:
83+
def get_canonical_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
8484
pass
8585

8686
@abstractmethod
87-
def get_block_by_root(self, block_root: Hash32) -> BaseBeaconBlock:
87+
def get_block_by_root(self,
88+
block_root: Hash32,
89+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
8890
pass
8991

9092
@abstractmethod
@@ -97,8 +99,9 @@ def block_exists(self, block_root: Hash32) -> bool:
9799

98100
@abstractmethod
99101
def persist_block_chain(
100-
self,
101-
blocks: Iterable[BaseBeaconBlock]
102+
self,
103+
blocks: Iterable[BaseBeaconBlock],
104+
block_class: Type[BaseBeaconBlock]
102105
) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]:
103106
pass
104107

@@ -127,20 +130,19 @@ def get(self, key: bytes) -> bytes:
127130

128131

129132
class BeaconChainDB(BaseBeaconChainDB):
130-
def __init__(self, db: BaseAtomicDB, block_class: Type[BaseBeaconBlock]) -> None:
133+
def __init__(self, db: BaseAtomicDB) -> None:
131134
self.db = db
132-
self.block_class = block_class
133-
134-
def set_block_class(self, block_class: Type[BaseBeaconBlock]) -> None:
135-
self.block_class = block_class
136135

137-
def persist_block(self,
138-
block: BaseBeaconBlock) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
136+
def persist_block(
137+
self,
138+
block: BaseBeaconBlock,
139+
block_class: Type[BaseBeaconBlock]
140+
) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
139141
"""
140142
Persist the given block.
141143
"""
142144
with self.db.atomic_batch() as db:
143-
return self._persist_block(db, block, self.block_class)
145+
return self._persist_block(db, block, block_class)
144146

145147
@classmethod
146148
def _persist_block(
@@ -188,14 +190,16 @@ def _get_canonical_block_root(db: BaseDB, slot: int) -> Hash32:
188190
else:
189191
return rlp.decode(encoded_key, sedes=rlp.sedes.binary)
190192

191-
def get_canonical_block_by_slot(self, slot: int) -> BaseBeaconBlock:
193+
def get_canonical_block_by_slot(self,
194+
slot: int,
195+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
192196
"""
193197
Return the block with the given slot in the canonical chain.
194198
195199
Raise BlockNotFound if there's no block with the given slot in the
196200
canonical chain.
197201
"""
198-
return self._get_canonical_block_by_slot(self.db, slot, self.block_class)
202+
return self._get_canonical_block_by_slot(self.db, slot, block_class)
199203

200204
@classmethod
201205
def _get_canonical_block_by_slot(
@@ -223,11 +227,11 @@ def _get_canonical_block_root_by_slot(
223227
validate_slot(slot)
224228
return cls._get_canonical_block_root(db, slot)
225229

226-
def get_canonical_head(self) -> BaseBeaconBlock:
230+
def get_canonical_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
227231
"""
228232
Return the current block at the head of the chain.
229233
"""
230-
return self._get_canonical_head(self.db, self.block_class)
234+
return self._get_canonical_head(self.db, block_class)
231235

232236
@classmethod
233237
def _get_canonical_head(cls,
@@ -239,8 +243,10 @@ def _get_canonical_head(cls,
239243
raise CanonicalHeadNotFound("No canonical head set for this chain")
240244
return cls._get_block_by_root(db, Hash32(canonical_head_root), block_class)
241245

242-
def get_block_by_root(self, block_root: Hash32) -> BaseBeaconBlock:
243-
return self._get_block_by_root(self.db, block_root, self.block_class)
246+
def get_block_by_root(self,
247+
block_root: Hash32,
248+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
249+
return self._get_block_by_root(self.db, block_root, block_class)
244250

245251
@staticmethod
246252
def _get_block_by_root(db: BaseDB,
@@ -280,15 +286,16 @@ def _block_exists(db: BaseDB, block_root: Hash32) -> bool:
280286
return block_root in db
281287

282288
def persist_block_chain(
283-
self,
284-
blocks: Iterable[BaseBeaconBlock]
289+
self,
290+
blocks: Iterable[BaseBeaconBlock],
291+
block_class: Type[BaseBeaconBlock]
285292
) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]:
286293
"""
287294
Return two iterable of blocks, the first containing the new canonical blocks,
288295
the second containing the old canonical headers
289296
"""
290297
with self.db.atomic_batch() as db:
291-
return self._persist_block_chain(db, blocks, self.block_class)
298+
return self._persist_block_chain(db, blocks, block_class)
292299

293300
@classmethod
294301
def _set_block_scores_to_db(

eth/beacon/state_machines/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def import_block(self, block: BaseBeaconBlock) -> Tuple[BeaconState, BaseBeaconB
7171
class BeaconStateMachine(BaseBeaconStateMachine):
7272
def __init__(self,
7373
chaindb: BaseBeaconChainDB,
74-
block_root: Hash32) -> None:
74+
block_root: Hash32,
75+
block_class: Type[BaseBeaconBlock]) -> None:
7576
self.chaindb = chaindb
76-
self.chaindb.set_block_class(self.get_block_class())
77-
self.block = self.chaindb.get_block_by_root(block_root)
77+
self.block = self.chaindb.get_block_by_root(block_root, block_class)
7878

7979
@property
8080
def state(self) -> BeaconState:

eth/beacon/types/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def from_root(cls, root: Hash32, chaindb: 'BaseBeaconChainDB') -> 'BeaconBlock':
173173
"""
174174
Returns the block denoted by the given block header.
175175
"""
176-
block = chaindb.get_block_by_root(root)
176+
block = chaindb.get_block_by_root(root, cls)
177177
body = cls.block_body_class(
178178
proposer_slashings=block.body.proposer_slashings,
179179
casper_slashings=block.body.casper_slashings,

tests/beacon/db/test_beacon_chaindb.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
@pytest.fixture
3535
def chaindb(base_db):
36-
return BeaconChainDB(base_db, BeaconBlock)
36+
return BeaconChainDB(base_db)
3737

3838

3939
@pytest.fixture(params=[0, 10, 999])
@@ -52,33 +52,33 @@ def state(sample_beacon_state_params):
5252
def test_chaindb_add_block_number_to_root_lookup(chaindb, block):
5353
block_slot_to_root_key = SchemaV1.make_block_slot_to_root_lookup_key(block.slot)
5454
assert not chaindb.exists(block_slot_to_root_key)
55-
chaindb.persist_block(block)
55+
chaindb.persist_block(block, block.__class__)
5656
assert chaindb.exists(block_slot_to_root_key)
5757

5858

5959
def test_chaindb_persist_block_and_slot_to_root(chaindb, block):
6060
with pytest.raises(BlockNotFound):
61-
chaindb.get_block_by_root(block.root)
61+
chaindb.get_block_by_root(block.root, block.__class__)
6262
slot_to_root_key = SchemaV1.make_block_root_to_score_lookup_key(block.root)
6363
assert not chaindb.exists(slot_to_root_key)
6464

65-
chaindb.persist_block(block)
65+
chaindb.persist_block(block, block.__class__)
6666

67-
assert chaindb.get_block_by_root(block.root) == block
67+
assert chaindb.get_block_by_root(block.root, block.__class__) == block
6868
assert chaindb.exists(slot_to_root_key)
6969

7070

7171
@given(seed=st.binary(min_size=32, max_size=32))
7272
def test_chaindb_persist_block_and_unknown_parent(chaindb, block, seed):
7373
n_block = block.copy(parent_root=hash_eth2(seed))
7474
with pytest.raises(ParentNotFound):
75-
chaindb.persist_block(n_block)
75+
chaindb.persist_block(n_block, n_block.__class__)
7676

7777

7878
def test_chaindb_persist_block_and_block_to_root(chaindb, block):
7979
block_to_root_key = SchemaV1.make_block_root_to_score_lookup_key(block.root)
8080
assert not chaindb.exists(block_to_root_key)
81-
chaindb.persist_block(block)
81+
chaindb.persist_block(block, block.__class__)
8282
assert chaindb.exists(block_to_root_key)
8383

8484

@@ -87,7 +87,7 @@ def test_chaindb_get_score(chaindb, sample_beacon_block_params):
8787
parent_root=GENESIS_PARENT_HASH,
8888
slot=0,
8989
)
90-
chaindb.persist_block(genesis)
90+
chaindb.persist_block(genesis, genesis.__class__)
9191

9292
genesis_score_key = SchemaV1.make_block_root_to_score_lookup_key(genesis.root)
9393
genesis_score = rlp.decode(chaindb.db.get(genesis_score_key), sedes=rlp.sedes.big_endian_int)
@@ -98,7 +98,7 @@ def test_chaindb_get_score(chaindb, sample_beacon_block_params):
9898
parent_root=genesis.root,
9999
slot=1,
100100
)
101-
chaindb.persist_block(block1)
101+
chaindb.persist_block(block1, block1.__class__)
102102

103103
block1_score_key = SchemaV1.make_block_root_to_score_lookup_key(block1.root)
104104
block1_score = rlp.decode(chaindb.db.get(block1_score_key), sedes=rlp.sedes.big_endian_int)
@@ -107,13 +107,13 @@ def test_chaindb_get_score(chaindb, sample_beacon_block_params):
107107

108108

109109
def test_chaindb_get_block_by_root(chaindb, block):
110-
chaindb.persist_block(block)
111-
result_block = chaindb.get_block_by_root(block.root)
110+
chaindb.persist_block(block, block.__class__)
111+
result_block = chaindb.get_block_by_root(block.root, block.__class__)
112112
validate_rlp_equal(result_block, block)
113113

114114

115115
def test_chaindb_get_canonical_block_root(chaindb, block):
116-
chaindb.persist_block(block)
116+
chaindb.persist_block(block, block.__class__)
117117
block_root = chaindb.get_canonical_block_root(block.slot)
118118
assert block_root == block.root
119119

tests/beacon/state_machines/test_demo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_demo(base_db,
3636
config,
3737
privkeys,
3838
pubkeys):
39-
chaindb = BeaconChainDB(base_db, SerenityBeaconBlock)
39+
chaindb = BeaconChainDB(base_db)
4040
state = genesis_state
4141
block = SerenityBeaconBlock(**sample_beacon_block_params).copy(
4242
slot=state.slot + 2,
@@ -68,11 +68,11 @@ def test_demo(base_db,
6868
)
6969

7070
# Store in chaindb
71-
chaindb.persist_block(block)
71+
chaindb.persist_block(block, SerenityBeaconBlock)
7272
chaindb.persist_state(state)
7373

7474
# Get state machine instance
75-
sm = fixture_sm_class(chaindb, block.root)
75+
sm = fixture_sm_class(chaindb, block.root, SerenityBeaconBlock)
7676
result_state, _ = sm.import_block(block)
7777

7878
assert state.slot == 0

0 commit comments

Comments
 (0)