Skip to content

Commit 8ecc627

Browse files
authored
Merge pull request #1679 from hwwhww/state_machine_init
Update StateMachine.__init__() and add `BeaconBlock`
2 parents 6259296 + a3d7c0b commit 8ecc627

File tree

11 files changed

+186
-69
lines changed

11 files changed

+186
-69
lines changed

eth/beacon/db/chain.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import (
55
Iterable,
66
Tuple,
7+
Type,
78
)
89
from cytoolz import (
910
first,
@@ -38,7 +39,10 @@
3839
)
3940

4041
from eth.beacon.types.states import BeaconState # noqa: F401
41-
from eth.beacon.types.blocks import BaseBeaconBlock # noqa: F401
42+
from eth.beacon.types.blocks import ( # noqa: F401
43+
BaseBeaconBlock,
44+
BeaconBlock,
45+
)
4246
from eth.beacon.validation import (
4347
validate_slot,
4448
)
@@ -48,6 +52,11 @@
4852

4953
class BaseBeaconChainDB(ABC):
5054
db = None # type: BaseAtomicDB
55+
block_class = None # type: Type[BaseBeaconBlock]
56+
57+
@abstractmethod
58+
def set_block_class(self, block_class: Type[BaseBeaconBlock]) -> None:
59+
pass
5160

5261
#
5362
# Block API
@@ -117,24 +126,33 @@ def get(self, key: bytes) -> bytes:
117126

118127

119128
class BeaconChainDB(BaseBeaconChainDB):
120-
def __init__(self, db: BaseAtomicDB) -> None:
129+
def __init__(self, db: BaseAtomicDB, block_class: Type[BaseBeaconBlock]) -> None:
121130
self.db = db
131+
self.block_class = block_class
132+
133+
def set_block_class(self, block_class: Type[BaseBeaconBlock]) -> None:
134+
self.block_class = block_class
122135

123136
def persist_block(self,
124137
block: BaseBeaconBlock) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
125138
"""
126139
Persist the given block.
127140
"""
128141
with self.db.atomic_batch() as db:
129-
return self._persist_block(db, block)
142+
return self._persist_block(db, block, self.block_class)
130143

131144
@classmethod
132145
def _persist_block(
133146
cls,
134147
db: 'BaseDB',
135-
block: BaseBeaconBlock) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
148+
block: BaseBeaconBlock,
149+
block_class: Type[BaseBeaconBlock]) -> Tuple[Tuple[bytes, ...], Tuple[bytes, ...]]:
136150
block_chain = (block, )
137-
new_canonical_blocks, old_canonical_blocks = cls._persist_block_chain(db, block_chain)
151+
new_canonical_blocks, old_canonical_blocks = cls._persist_block_chain(
152+
db,
153+
block_chain,
154+
block_class,
155+
)
138156

139157
return new_canonical_blocks, old_canonical_blocks
140158

@@ -176,15 +194,16 @@ def get_canonical_block_by_slot(self, slot: int) -> BaseBeaconBlock:
176194
Raise BlockNotFound if there's no block with the given slot in the
177195
canonical chain.
178196
"""
179-
return self._get_canonical_block_by_slot(self.db, slot)
197+
return self._get_canonical_block_by_slot(self.db, slot, self.block_class)
180198

181199
@classmethod
182200
def _get_canonical_block_by_slot(
183201
cls,
184202
db: BaseDB,
185-
slot: int) -> BaseBeaconBlock:
203+
slot: int,
204+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
186205
canonical_block_root = cls._get_canonical_block_root_by_slot(db, slot)
187-
return cls._get_block_by_root(db, canonical_block_root)
206+
return cls._get_block_by_root(db, canonical_block_root, block_class)
188207

189208
def get_canonical_block_root_by_slot(self, slot: int) -> Hash32:
190209
"""
@@ -207,21 +226,25 @@ def get_canonical_head(self) -> BaseBeaconBlock:
207226
"""
208227
Return the current block at the head of the chain.
209228
"""
210-
return self._get_canonical_head(self.db)
229+
return self._get_canonical_head(self.db, self.block_class)
211230

212231
@classmethod
213-
def _get_canonical_head(cls, db: BaseDB) -> BaseBeaconBlock:
232+
def _get_canonical_head(cls,
233+
db: BaseDB,
234+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
214235
try:
215236
canonical_head_root = db[SchemaV1.make_canonical_head_root_lookup_key()]
216237
except KeyError:
217238
raise CanonicalHeadNotFound("No canonical head set for this chain")
218-
return cls._get_block_by_root(db, Hash32(canonical_head_root))
239+
return cls._get_block_by_root(db, Hash32(canonical_head_root), block_class)
219240

220241
def get_block_by_root(self, block_root: Hash32) -> BaseBeaconBlock:
221-
return self._get_block_by_root(self.db, block_root)
242+
return self._get_block_by_root(self.db, block_root, self.block_class)
222243

223244
@staticmethod
224-
def _get_block_by_root(db: BaseDB, block_root: Hash32) -> BaseBeaconBlock:
245+
def _get_block_by_root(db: BaseDB,
246+
block_root: Hash32,
247+
block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
225248
"""
226249
Return the requested block header as specified by block root.
227250
@@ -233,7 +256,7 @@ def _get_block_by_root(db: BaseDB, block_root: Hash32) -> BaseBeaconBlock:
233256
except KeyError:
234257
raise BlockNotFound("No block with root {0} found".format(
235258
encode_hex(block_root)))
236-
return _decode_block(block_rlp)
259+
return _decode_block(block_rlp, block_class)
237260

238261
def get_score(self, block_root: Hash32) -> int:
239262
return self._get_score(self.db, block_root)
@@ -264,13 +287,14 @@ def persist_block_chain(
264287
the second containing the old canonical headers
265288
"""
266289
with self.db.atomic_batch() as db:
267-
return self._persist_block_chain(db, blocks)
290+
return self._persist_block_chain(db, blocks, self.block_class)
268291

269292
@classmethod
270293
def _persist_block_chain(
271294
cls,
272295
db: BaseDB,
273-
blocks: Iterable[BaseBeaconBlock]
296+
blocks: Iterable[BaseBeaconBlock],
297+
block_class: Type[BaseBeaconBlock]
274298
) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]:
275299
try:
276300
first_block = first(blocks)
@@ -313,20 +337,23 @@ def _persist_block_chain(
313337
)
314338

315339
try:
316-
previous_canonical_head = cls._get_canonical_head(db).root
340+
previous_canonical_head = cls._get_canonical_head(db, block_class).root
317341
head_score = cls._get_score(db, previous_canonical_head)
318342
except CanonicalHeadNotFound:
319-
return cls._set_as_canonical_chain_head(db, block.root)
343+
return cls._set_as_canonical_chain_head(db, block.root, block_class)
320344

321345
if score > head_score:
322-
return cls._set_as_canonical_chain_head(db, block.root)
346+
return cls._set_as_canonical_chain_head(db, block.root, block_class)
323347
else:
324348
return tuple(), tuple()
325349

326350
@classmethod
327351
def _set_as_canonical_chain_head(
328-
cls, db: BaseDB,
329-
block_root: Hash32) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]:
352+
cls,
353+
db: BaseDB,
354+
block_root: Hash32,
355+
block_class: Type[BaseBeaconBlock]
356+
) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]:
330357
"""
331358
Set the canonical chain HEAD to the block as specified by the
332359
given block root.
@@ -335,13 +362,13 @@ def _set_as_canonical_chain_head(
335362
are no longer in the canonical chain
336363
"""
337364
try:
338-
block = cls._get_block_by_root(db, block_root)
365+
block = cls._get_block_by_root(db, block_root, block_class)
339366
except BlockNotFound:
340367
raise ValueError(
341368
"Cannot use unknown block root as canonical head: {}".format(block_root)
342369
)
343370

344-
new_canonical_blocks = tuple(reversed(cls._find_new_ancestors(db, block)))
371+
new_canonical_blocks = tuple(reversed(cls._find_new_ancestors(db, block, block_class)))
345372
old_canonical_blocks = []
346373

347374
for block in new_canonical_blocks:
@@ -351,7 +378,7 @@ def _set_as_canonical_chain_head(
351378
# no old_canonical block, and no more possible
352379
break
353380
else:
354-
old_canonical_block = cls._get_block_by_root(db, old_canonical_root)
381+
old_canonical_block = cls._get_block_by_root(db, old_canonical_root, block_class)
355382
old_canonical_blocks.append(old_canonical_block)
356383

357384
for block in new_canonical_blocks:
@@ -363,7 +390,11 @@ def _set_as_canonical_chain_head(
363390

364391
@classmethod
365392
@to_tuple
366-
def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[BaseBeaconBlock]:
393+
def _find_new_ancestors(
394+
cls,
395+
db: BaseDB,
396+
block: BaseBeaconBlock,
397+
block_class: Type[BaseBeaconBlock]) -> Iterable[BaseBeaconBlock]:
367398
"""
368399
Return the chain leading up from the given block until (but not including)
369400
the first ancestor it has in common with our canonical chain.
@@ -377,7 +408,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
377408
"""
378409
while True:
379410
try:
380-
orig = cls._get_canonical_block_by_slot(db, block.slot)
411+
orig = cls._get_canonical_block_by_slot(db, block.slot, block_class)
381412
except BlockNotFound:
382413
# This just means the block is not on the canonical chain.
383414
pass
@@ -392,7 +423,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
392423
if block.parent_root == GENESIS_PARENT_HASH:
393424
break
394425
else:
395-
block = cls._get_block_by_root(db, block.parent_root)
426+
block = cls._get_block_by_root(db, block.parent_root, block_class)
396427

397428
@staticmethod
398429
def _add_block_slot_to_root_lookup(db: BaseDB, block: BaseBeaconBlock) -> None:
@@ -466,9 +497,8 @@ def get(self, key: bytes) -> bytes:
466497
# relatively expensive so we cache that here, but use a small cache because we *should* only
467498
# be looking up recent blocks.
468499
@functools.lru_cache(128)
469-
def _decode_block(block_rlp: bytes) -> BaseBeaconBlock:
470-
# TODO: forkable Block fields?
471-
return rlp.decode(block_rlp, sedes=BaseBeaconBlock)
500+
def _decode_block(block_rlp: bytes, sedes: Type[BaseBeaconBlock]) -> BaseBeaconBlock:
501+
return rlp.decode(block_rlp, sedes=sedes)
472502

473503

474504
@functools.lru_cache(128)

eth/beacon/state_machines/base.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
Type,
88
)
99

10+
from eth_typing import (
11+
Hash32,
12+
)
13+
1014
from eth._utils.datatypes import (
1115
Configurable,
1216
)
@@ -30,7 +34,7 @@ class BaseBeaconStateMachine(Configurable, ABC):
3034
config = None # type: BeaconConfig
3135

3236
block = None # type: BaseBeaconBlock
33-
state = None # type: BeaconState
37+
_state = None # type: BeaconState
3438

3539
block_class = None # type: Type[BaseBeaconBlock]
3640
state_class = None # type: Type[BeaconState]
@@ -67,10 +71,16 @@ def import_block(self, block: BaseBeaconBlock) -> Tuple[BeaconState, BaseBeaconB
6771
class BeaconStateMachine(BaseBeaconStateMachine):
6872
def __init__(self,
6973
chaindb: BaseBeaconChainDB,
70-
block: BaseBeaconBlock,
71-
state: BeaconState) -> None:
72-
# TODO: get state from DB, now it's just a stub!
73-
self.state = state
74+
block_root: Hash32) -> None:
75+
self.chaindb = chaindb
76+
self.chaindb.set_block_class(self.get_block_class())
77+
self.block = self.chaindb.get_block_by_root(block_root)
78+
79+
@property
80+
def state(self) -> BeaconState:
81+
if self._state is None:
82+
self._state = self.chaindb.get_state_by_root(self.block.state_root)
83+
return self._state
7484

7585
@classmethod
7686
def get_block_class(cls) -> Type[BaseBeaconBlock]:
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1-
from eth.beacon.types.blocks import BaseBeaconBlock
1+
from eth.beacon.types.blocks import (
2+
BeaconBlock,
3+
BeaconBlockBody,
4+
)
25

36

4-
class SerenityBeaconBlock(BaseBeaconBlock):
7+
class SerenityBeaconBlockBody(BeaconBlockBody):
8+
pass
9+
10+
11+
class SerenityBeaconBlock(BeaconBlock):
512
pass

0 commit comments

Comments
 (0)