|
1 | 1 | from abc import ABC, abstractmethod
|
2 | 2 | import functools
|
3 |
| -from typing import Tuple, Iterable |
| 3 | +from typing import Iterable, Tuple |
4 | 4 |
|
5 | 5 | import rlp
|
6 | 6 |
|
| 7 | +from cytoolz import ( |
| 8 | + first, |
| 9 | + sliding_window, |
| 10 | +) |
| 11 | + |
7 | 12 | from eth_utils import (
|
8 | 13 | encode_hex,
|
9 | 14 | to_tuple,
|
| 15 | + ValidationError, |
10 | 16 | )
|
11 | 17 |
|
12 | 18 | from eth_typing import (
|
@@ -73,6 +79,12 @@ def persist_header(self,
|
73 | 79 | ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
|
74 | 80 | raise NotImplementedError("ChainDB classes must implement this method")
|
75 | 81 |
|
| 82 | + @abstractmethod |
| 83 | + def persist_header_chain(self, |
| 84 | + headers: Iterable[BlockHeader] |
| 85 | + ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]: |
| 86 | + raise NotImplementedError("ChainDB classes must implement this method") |
| 87 | + |
76 | 88 |
|
77 | 89 | class HeaderDB(BaseHeaderDB):
|
78 | 90 | #
|
@@ -149,43 +161,66 @@ def header_exists(self, block_hash: Hash32) -> bool:
|
149 | 161 | def persist_header(self,
|
150 | 162 | header: BlockHeader
|
151 | 163 | ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
|
| 164 | + return self.persist_header_chain((header,)) |
| 165 | + |
| 166 | + def persist_header_chain(self, |
| 167 | + headers: Iterable[BlockHeader] |
| 168 | + ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]: |
152 | 169 | """
|
153 |
| - :returns: iterable of headers newly on the canonical chain |
| 170 | + Return two iterable of headers, the first containing the new canonical headers, |
| 171 | + the second containing the old canonical headers |
154 | 172 | """
|
155 |
| - if header.parent_hash != GENESIS_PARENT_HASH: |
156 |
| - try: |
157 |
| - self.get_block_header_by_hash(header.parent_hash) |
158 |
| - except HeaderNotFound: |
| 173 | + |
| 174 | + try: |
| 175 | + first_header = first(headers) |
| 176 | + except StopIteration: |
| 177 | + return tuple(), tuple() |
| 178 | + else: |
| 179 | + |
| 180 | + for parent, child in sliding_window(2, headers): |
| 181 | + if parent.hash != child.parent_hash: |
| 182 | + raise ValidationError( |
| 183 | + "Non-contiguous chain. Expected {} to have {} as parent but was {}".format( |
| 184 | + encode_hex(child.hash), |
| 185 | + encode_hex(parent.hash), |
| 186 | + encode_hex(child.parent_hash), |
| 187 | + ) |
| 188 | + ) |
| 189 | + |
| 190 | + is_genesis = first_header.parent_hash == GENESIS_PARENT_HASH |
| 191 | + if not is_genesis and not self.header_exists(first_header.parent_hash): |
159 | 192 | raise ParentNotFound(
|
160 | 193 | "Cannot persist block header ({}) with unknown parent ({})".format(
|
161 |
| - encode_hex(header.hash), encode_hex(header.parent_hash))) |
| 194 | + encode_hex(first_header.hash), encode_hex(first_header.parent_hash))) |
162 | 195 |
|
163 |
| - self.db.set( |
164 |
| - header.hash, |
165 |
| - rlp.encode(header), |
166 |
| - ) |
| 196 | + score = 0 if is_genesis else self.get_score(first_header.parent_hash) |
167 | 197 |
|
168 |
| - if header.parent_hash == GENESIS_PARENT_HASH: |
169 |
| - score = header.difficulty |
170 |
| - else: |
171 |
| - score = self.get_score(header.parent_hash) + header.difficulty |
| 198 | + for header in headers: |
| 199 | + self.db.set( |
| 200 | + header.hash, |
| 201 | + rlp.encode(header), |
| 202 | + ) |
172 | 203 |
|
173 |
| - self.db.set( |
174 |
| - SchemaV1.make_block_hash_to_score_lookup_key(header.hash), |
175 |
| - rlp.encode(score, sedes=rlp.sedes.big_endian_int), |
176 |
| - ) |
| 204 | + score += header.difficulty |
| 205 | + |
| 206 | + self.db.set( |
| 207 | + SchemaV1.make_block_hash_to_score_lookup_key(header.hash), |
| 208 | + rlp.encode(score, sedes=rlp.sedes.big_endian_int), |
| 209 | + ) |
177 | 210 |
|
178 | 211 | try:
|
179 | 212 | head_score = self.get_score(self.get_canonical_head().hash)
|
180 | 213 | except CanonicalHeadNotFound:
|
181 |
| - new_canonical_headers, old_canonical_headers = self._set_as_canonical_chain_head( |
182 |
| - header.hash, |
183 |
| - ) |
| 214 | + ( |
| 215 | + new_canonical_headers, |
| 216 | + old_canonical_headers |
| 217 | + ) = self._set_as_canonical_chain_head(header.hash) |
184 | 218 | else:
|
185 | 219 | if score > head_score:
|
186 |
| - new_canonical_headers, old_canonical_headers = self._set_as_canonical_chain_head( |
187 |
| - header.hash |
188 |
| - ) |
| 220 | + ( |
| 221 | + new_canonical_headers, |
| 222 | + old_canonical_headers |
| 223 | + ) = self._set_as_canonical_chain_head(header.hash) |
189 | 224 | else:
|
190 | 225 | new_canonical_headers = tuple()
|
191 | 226 | old_canonical_headers = tuple()
|
@@ -297,6 +332,10 @@ async def coro_get_canonical_block_hash(self, block_number: BlockNumber) -> Hash
|
297 | 332 | async def coro_persist_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
|
298 | 333 | raise NotImplementedError()
|
299 | 334 |
|
| 335 | + async def coro_persist_header_chain(self, |
| 336 | + headers: Iterable[BlockHeader]) -> Tuple[BlockHeader, ...]: |
| 337 | + raise NotImplementedError() |
| 338 | + |
300 | 339 |
|
301 | 340 | # When performing a chain sync (either fast or regular modes), we'll very often need to look
|
302 | 341 | # up recent block headers to validate the chain, and decoding their RLP representation is
|
|
0 commit comments