4
4
from typing import (
5
5
Iterable ,
6
6
Tuple ,
7
+ Type ,
7
8
)
8
9
from cytoolz import (
9
10
first ,
38
39
)
39
40
40
41
from eth .beacon .types .states import BeaconState # noqa: F401
41
- from eth .beacon .types .blocks import BaseBeaconBlock # noqa: F401
42
+ from eth .beacon .types .blocks import ( # noqa: F401
43
+ BaseBeaconBlock ,
44
+ BeaconBlock ,
45
+ )
42
46
from eth .beacon .validation import (
43
47
validate_slot ,
44
48
)
48
52
49
53
class BaseBeaconChainDB (ABC ):
50
54
db = None # type: BaseAtomicDB
55
+ block_class = None # type: Type[BaseBeaconBlock]
56
+
57
+ @abstractmethod
58
+ def set_block_class (self , block_class : Type [BaseBeaconBlock ]) -> None :
59
+ pass
51
60
52
61
#
53
62
# Block API
@@ -117,24 +126,33 @@ def get(self, key: bytes) -> bytes:
117
126
118
127
119
128
class BeaconChainDB (BaseBeaconChainDB ):
120
- def __init__ (self , db : BaseAtomicDB ) -> None :
129
+ def __init__ (self , db : BaseAtomicDB , block_class : Type [ BaseBeaconBlock ] ) -> None :
121
130
self .db = db
131
+ self .block_class = block_class
132
+
133
+ def set_block_class (self , block_class : Type [BaseBeaconBlock ]) -> None :
134
+ self .block_class = block_class
122
135
123
136
def persist_block (self ,
124
137
block : BaseBeaconBlock ) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
125
138
"""
126
139
Persist the given block.
127
140
"""
128
141
with self .db .atomic_batch () as db :
129
- return self ._persist_block (db , block )
142
+ return self ._persist_block (db , block , self . block_class )
130
143
131
144
@classmethod
132
145
def _persist_block (
133
146
cls ,
134
147
db : 'BaseDB' ,
135
- block : BaseBeaconBlock ) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
148
+ block : BaseBeaconBlock ,
149
+ block_class : Type [BaseBeaconBlock ]) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
136
150
block_chain = (block , )
137
- new_canonical_blocks , old_canonical_blocks = cls ._persist_block_chain (db , block_chain )
151
+ new_canonical_blocks , old_canonical_blocks = cls ._persist_block_chain (
152
+ db ,
153
+ block_chain ,
154
+ block_class ,
155
+ )
138
156
139
157
return new_canonical_blocks , old_canonical_blocks
140
158
@@ -176,15 +194,16 @@ def get_canonical_block_by_slot(self, slot: int) -> BaseBeaconBlock:
176
194
Raise BlockNotFound if there's no block with the given slot in the
177
195
canonical chain.
178
196
"""
179
- return self ._get_canonical_block_by_slot (self .db , slot )
197
+ return self ._get_canonical_block_by_slot (self .db , slot , self . block_class )
180
198
181
199
@classmethod
182
200
def _get_canonical_block_by_slot (
183
201
cls ,
184
202
db : BaseDB ,
185
- slot : int ) -> BaseBeaconBlock :
203
+ slot : int ,
204
+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
186
205
canonical_block_root = cls ._get_canonical_block_root_by_slot (db , slot )
187
- return cls ._get_block_by_root (db , canonical_block_root )
206
+ return cls ._get_block_by_root (db , canonical_block_root , block_class )
188
207
189
208
def get_canonical_block_root_by_slot (self , slot : int ) -> Hash32 :
190
209
"""
@@ -207,21 +226,25 @@ def get_canonical_head(self) -> BaseBeaconBlock:
207
226
"""
208
227
Return the current block at the head of the chain.
209
228
"""
210
- return self ._get_canonical_head (self .db )
229
+ return self ._get_canonical_head (self .db , self . block_class )
211
230
212
231
@classmethod
213
- def _get_canonical_head (cls , db : BaseDB ) -> BaseBeaconBlock :
232
+ def _get_canonical_head (cls ,
233
+ db : BaseDB ,
234
+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
214
235
try :
215
236
canonical_head_root = db [SchemaV1 .make_canonical_head_root_lookup_key ()]
216
237
except KeyError :
217
238
raise CanonicalHeadNotFound ("No canonical head set for this chain" )
218
- return cls ._get_block_by_root (db , Hash32 (canonical_head_root ))
239
+ return cls ._get_block_by_root (db , Hash32 (canonical_head_root ), block_class )
219
240
220
241
def get_block_by_root (self , block_root : Hash32 ) -> BaseBeaconBlock :
221
- return self ._get_block_by_root (self .db , block_root )
242
+ return self ._get_block_by_root (self .db , block_root , self . block_class )
222
243
223
244
@staticmethod
224
- def _get_block_by_root (db : BaseDB , block_root : Hash32 ) -> BaseBeaconBlock :
245
+ def _get_block_by_root (db : BaseDB ,
246
+ block_root : Hash32 ,
247
+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
225
248
"""
226
249
Return the requested block header as specified by block root.
227
250
@@ -233,7 +256,7 @@ def _get_block_by_root(db: BaseDB, block_root: Hash32) -> BaseBeaconBlock:
233
256
except KeyError :
234
257
raise BlockNotFound ("No block with root {0} found" .format (
235
258
encode_hex (block_root )))
236
- return _decode_block (block_rlp )
259
+ return _decode_block (block_rlp , block_class )
237
260
238
261
def get_score (self , block_root : Hash32 ) -> int :
239
262
return self ._get_score (self .db , block_root )
@@ -264,13 +287,14 @@ def persist_block_chain(
264
287
the second containing the old canonical headers
265
288
"""
266
289
with self .db .atomic_batch () as db :
267
- return self ._persist_block_chain (db , blocks )
290
+ return self ._persist_block_chain (db , blocks , self . block_class )
268
291
269
292
@classmethod
270
293
def _persist_block_chain (
271
294
cls ,
272
295
db : BaseDB ,
273
- blocks : Iterable [BaseBeaconBlock ]
296
+ blocks : Iterable [BaseBeaconBlock ],
297
+ block_class : Type [BaseBeaconBlock ]
274
298
) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
275
299
try :
276
300
first_block = first (blocks )
@@ -313,20 +337,23 @@ def _persist_block_chain(
313
337
)
314
338
315
339
try :
316
- previous_canonical_head = cls ._get_canonical_head (db ).root
340
+ previous_canonical_head = cls ._get_canonical_head (db , block_class ).root
317
341
head_score = cls ._get_score (db , previous_canonical_head )
318
342
except CanonicalHeadNotFound :
319
- return cls ._set_as_canonical_chain_head (db , block .root )
343
+ return cls ._set_as_canonical_chain_head (db , block .root , block_class )
320
344
321
345
if score > head_score :
322
- return cls ._set_as_canonical_chain_head (db , block .root )
346
+ return cls ._set_as_canonical_chain_head (db , block .root , block_class )
323
347
else :
324
348
return tuple (), tuple ()
325
349
326
350
@classmethod
327
351
def _set_as_canonical_chain_head (
328
- cls , db : BaseDB ,
329
- block_root : Hash32 ) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
352
+ cls ,
353
+ db : BaseDB ,
354
+ block_root : Hash32 ,
355
+ block_class : Type [BaseBeaconBlock ]
356
+ ) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
330
357
"""
331
358
Set the canonical chain HEAD to the block as specified by the
332
359
given block root.
@@ -335,13 +362,13 @@ def _set_as_canonical_chain_head(
335
362
are no longer in the canonical chain
336
363
"""
337
364
try :
338
- block = cls ._get_block_by_root (db , block_root )
365
+ block = cls ._get_block_by_root (db , block_root , block_class )
339
366
except BlockNotFound :
340
367
raise ValueError (
341
368
"Cannot use unknown block root as canonical head: {}" .format (block_root )
342
369
)
343
370
344
- new_canonical_blocks = tuple (reversed (cls ._find_new_ancestors (db , block )))
371
+ new_canonical_blocks = tuple (reversed (cls ._find_new_ancestors (db , block , block_class )))
345
372
old_canonical_blocks = []
346
373
347
374
for block in new_canonical_blocks :
@@ -351,7 +378,7 @@ def _set_as_canonical_chain_head(
351
378
# no old_canonical block, and no more possible
352
379
break
353
380
else :
354
- old_canonical_block = cls ._get_block_by_root (db , old_canonical_root )
381
+ old_canonical_block = cls ._get_block_by_root (db , old_canonical_root , block_class )
355
382
old_canonical_blocks .append (old_canonical_block )
356
383
357
384
for block in new_canonical_blocks :
@@ -363,7 +390,11 @@ def _set_as_canonical_chain_head(
363
390
364
391
@classmethod
365
392
@to_tuple
366
- def _find_new_ancestors (cls , db : BaseDB , block : BaseBeaconBlock ) -> Iterable [BaseBeaconBlock ]:
393
+ def _find_new_ancestors (
394
+ cls ,
395
+ db : BaseDB ,
396
+ block : BaseBeaconBlock ,
397
+ block_class : Type [BaseBeaconBlock ]) -> Iterable [BaseBeaconBlock ]:
367
398
"""
368
399
Return the chain leading up from the given block until (but not including)
369
400
the first ancestor it has in common with our canonical chain.
@@ -377,7 +408,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
377
408
"""
378
409
while True :
379
410
try :
380
- orig = cls ._get_canonical_block_by_slot (db , block .slot )
411
+ orig = cls ._get_canonical_block_by_slot (db , block .slot , block_class )
381
412
except BlockNotFound :
382
413
# This just means the block is not on the canonical chain.
383
414
pass
@@ -392,7 +423,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
392
423
if block .parent_root == GENESIS_PARENT_HASH :
393
424
break
394
425
else :
395
- block = cls ._get_block_by_root (db , block .parent_root )
426
+ block = cls ._get_block_by_root (db , block .parent_root , block_class )
396
427
397
428
@staticmethod
398
429
def _add_block_slot_to_root_lookup (db : BaseDB , block : BaseBeaconBlock ) -> None :
@@ -466,9 +497,8 @@ def get(self, key: bytes) -> bytes:
466
497
# relatively expensive so we cache that here, but use a small cache because we *should* only
467
498
# be looking up recent blocks.
468
499
@functools .lru_cache (128 )
469
- def _decode_block (block_rlp : bytes ) -> BaseBeaconBlock :
470
- # TODO: forkable Block fields?
471
- return rlp .decode (block_rlp , sedes = BaseBeaconBlock )
500
+ def _decode_block (block_rlp : bytes , sedes : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
501
+ return rlp .decode (block_rlp , sedes = sedes )
472
502
473
503
474
504
@functools .lru_cache (128 )
0 commit comments