Skip to content

Commit 1408cfe

Browse files
committed
Add persist_chain_header API to HeaderDB
1 parent 35d6d15 commit 1408cfe

File tree

6 files changed

+99
-90
lines changed

6 files changed

+99
-90
lines changed

eth/db/chain.py

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,9 @@
3434

3535
from eth.constants import (
3636
EMPTY_UNCLE_HASH,
37-
GENESIS_PARENT_HASH,
3837
)
3938
from eth.exceptions import (
40-
CanonicalHeadNotFound,
4139
HeaderNotFound,
42-
ParentNotFound,
4340
TransactionNotFound,
4441
)
4542
from eth.db.header import BaseHeaderDB, HeaderDB
@@ -186,63 +183,14 @@ def get_block_uncles(self, uncles_hash: Hash32) -> List[BlockHeader]:
186183
else:
187184
return rlp.decode(encoded_uncles, sedes=rlp.sedes.CountableList(BlockHeader))
188185

189-
# TODO: This method should take a chain of headers as that's the most common use case
190-
# and it'd be much faster than inserting each header individually.
191-
def persist_header(self,
192-
header: BlockHeader
193-
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
194-
"""
195-
Returns iterable of headers newly on the canonical chain
196-
"""
197-
is_genesis = header.parent_hash == GENESIS_PARENT_HASH
198-
if not is_genesis and not self.header_exists(header.parent_hash):
199-
raise ParentNotFound(
200-
"Cannot persist block header ({}) with unknown parent ({})".format(
201-
encode_hex(header.hash), encode_hex(header.parent_hash)))
202-
203-
self.db.set(
204-
header.hash,
205-
rlp.encode(header),
206-
)
207-
208-
if is_genesis:
209-
score = header.difficulty
210-
else:
211-
score = self.get_score(header.parent_hash) + header.difficulty
212-
213-
self.db.set(
214-
SchemaV1.make_block_hash_to_score_lookup_key(header.hash),
215-
rlp.encode(score, sedes=rlp.sedes.big_endian_int),
216-
)
217-
218-
try:
219-
head_score = self.get_score(self.get_canonical_head().hash)
220-
except CanonicalHeadNotFound:
221-
(
222-
new_canonical_headers,
223-
old_canonical_headers
224-
) = self._set_as_canonical_chain_head(header)
225-
else:
226-
if score > head_score:
227-
(
228-
new_canonical_headers,
229-
old_canonical_headers
230-
) = self._set_as_canonical_chain_head(header)
231-
else:
232-
new_canonical_headers = tuple()
233-
old_canonical_headers = tuple()
234-
235-
return new_canonical_headers, old_canonical_headers
236-
237-
# TODO: update this to take a `hash` rather than a full header object.
238186
def _set_as_canonical_chain_head(self,
239-
header: BlockHeader
187+
block_hash: Hash32
240188
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
241189
"""
242190
Returns iterable of headers newly on the canonical head
243191
"""
244192
try:
245-
self.get_block_header_by_hash(header.hash)
193+
header = self.get_block_header_by_hash(block_hash)
246194
except HeaderNotFound:
247195
raise ValueError("Cannot use unknown block hash as canonical head: {}".format(
248196
header.hash))

eth/db/header.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from abc import ABC, abstractmethod
22
import functools
3-
from typing import Tuple, Iterable
3+
from typing import Iterable, Tuple
44

55
import rlp
66

7+
from cytoolz import (
8+
first,
9+
sliding_window,
10+
)
11+
712
from eth_utils import (
813
encode_hex,
914
to_tuple,
15+
ValidationError,
1016
)
1117

1218
from eth_typing import (
@@ -73,6 +79,12 @@ def persist_header(self,
7379
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
7480
raise NotImplementedError("ChainDB classes must implement this method")
7581

82+
@abstractmethod
83+
def persist_header_chain(self,
84+
headers: Iterable[BlockHeader]
85+
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
86+
raise NotImplementedError("ChainDB classes must implement this method")
87+
7688

7789
class HeaderDB(BaseHeaderDB):
7890
#
@@ -149,43 +161,66 @@ def header_exists(self, block_hash: Hash32) -> bool:
149161
def persist_header(self,
150162
header: BlockHeader
151163
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
164+
return self.persist_header_chain((header,))
165+
166+
def persist_header_chain(self,
167+
headers: Iterable[BlockHeader]
168+
) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]:
152169
"""
153-
:returns: iterable of headers newly on the canonical chain
170+
Return two iterable of headers, the first containing the new canonical headers,
171+
the second containing the old canonical headers
154172
"""
155-
if header.parent_hash != GENESIS_PARENT_HASH:
156-
try:
157-
self.get_block_header_by_hash(header.parent_hash)
158-
except HeaderNotFound:
173+
174+
try:
175+
first_header = first(headers)
176+
except StopIteration:
177+
return tuple(), tuple()
178+
else:
179+
180+
for parent, child in sliding_window(2, headers):
181+
if parent.hash != child.parent_hash:
182+
raise ValidationError(
183+
"Non-contiguous chain. Expected {} to have {} as parent but was {}".format(
184+
encode_hex(child.hash),
185+
encode_hex(parent.hash),
186+
encode_hex(child.parent_hash),
187+
)
188+
)
189+
190+
is_genesis = first_header.parent_hash == GENESIS_PARENT_HASH
191+
if not is_genesis and not self.header_exists(first_header.parent_hash):
159192
raise ParentNotFound(
160193
"Cannot persist block header ({}) with unknown parent ({})".format(
161-
encode_hex(header.hash), encode_hex(header.parent_hash)))
194+
encode_hex(first_header.hash), encode_hex(first_header.parent_hash)))
162195

163-
self.db.set(
164-
header.hash,
165-
rlp.encode(header),
166-
)
196+
score = 0 if is_genesis else self.get_score(first_header.parent_hash)
167197

168-
if header.parent_hash == GENESIS_PARENT_HASH:
169-
score = header.difficulty
170-
else:
171-
score = self.get_score(header.parent_hash) + header.difficulty
198+
for header in headers:
199+
self.db.set(
200+
header.hash,
201+
rlp.encode(header),
202+
)
172203

173-
self.db.set(
174-
SchemaV1.make_block_hash_to_score_lookup_key(header.hash),
175-
rlp.encode(score, sedes=rlp.sedes.big_endian_int),
176-
)
204+
score += header.difficulty
205+
206+
self.db.set(
207+
SchemaV1.make_block_hash_to_score_lookup_key(header.hash),
208+
rlp.encode(score, sedes=rlp.sedes.big_endian_int),
209+
)
177210

178211
try:
179212
head_score = self.get_score(self.get_canonical_head().hash)
180213
except CanonicalHeadNotFound:
181-
new_canonical_headers, old_canonical_headers = self._set_as_canonical_chain_head(
182-
header.hash,
183-
)
214+
(
215+
new_canonical_headers,
216+
old_canonical_headers
217+
) = self._set_as_canonical_chain_head(header.hash)
184218
else:
185219
if score > head_score:
186-
new_canonical_headers, old_canonical_headers = self._set_as_canonical_chain_head(
187-
header.hash
188-
)
220+
(
221+
new_canonical_headers,
222+
old_canonical_headers
223+
) = self._set_as_canonical_chain_head(header.hash)
189224
else:
190225
new_canonical_headers = tuple()
191226
old_canonical_headers = tuple()
@@ -297,6 +332,10 @@ async def coro_get_canonical_block_hash(self, block_number: BlockNumber) -> Hash
297332
async def coro_persist_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
298333
raise NotImplementedError()
299334

335+
async def coro_persist_header_chain(self,
336+
headers: Iterable[BlockHeader]) -> Tuple[BlockHeader, ...]:
337+
raise NotImplementedError()
338+
300339

301340
# When performing a chain sync (either fast or regular modes), we'll very often need to look
302341
# up recent block headers to validate the chain, and decoding their RLP representation is

tests/database/test_header_db.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from eth_utils import (
99
to_tuple,
1010
keccak,
11+
ValidationError,
1112
)
1213

1314
from eth.constants import (
@@ -80,8 +81,7 @@ def test_headerdb_get_canonical_head_with_header_chain(headerdb, genesis_header)
8081

8182
headers = mk_header_chain(genesis_header, length=10)
8283

83-
for header in headers:
84-
headerdb.persist_header(header)
84+
headerdb.persist_header_chain(headers)
8585

8686
head = headerdb.get_canonical_head()
8787
assert_headers_eq(head, headers[-1])
@@ -98,6 +98,17 @@ def test_headerdb_persist_header_disallows_unknown_parent(headerdb):
9898
headerdb.persist_header(header)
9999

100100

101+
def test_headerdb_persist_header_chain_disallows_non_contiguous_chain(headerdb, genesis_header):
102+
headerdb.persist_header(genesis_header)
103+
104+
headers = mk_header_chain(genesis_header, length=3)
105+
106+
non_contiguous_headers = (headers[0], headers[2], headers[1],)
107+
108+
with pytest.raises(ValidationError, match="Non-contiguous chain"):
109+
headerdb.persist_header_chain(non_contiguous_headers)
110+
111+
101112
def test_headerdb_persist_header_returns_new_canonical_chain(headerdb, genesis_header):
102113
gen_result, _ = headerdb.persist_header(genesis_header)
103114
assert gen_result == (genesis_header,)
@@ -144,8 +155,7 @@ def test_headerdb_get_score_for_non_genesis_headers(headerdb, genesis_header):
144155
difficulties = tuple(h.difficulty for h in headers)
145156
scores = tuple(accumulate(operator.add, difficulties, genesis_header.difficulty))
146157

147-
for header in headers:
148-
headerdb.persist_header(header)
158+
headerdb.persist_header_chain(headers)
149159

150160
for header, expected_score in zip(headers, scores[1:]):
151161
actual_score = headerdb.get_score(header.hash)
@@ -212,8 +222,7 @@ def test_headerdb_header_retrieval_by_hash(headerdb, genesis_header):
212222

213223
headers = mk_header_chain(genesis_header, length=10)
214224

215-
for header in headers:
216-
headerdb.persist_header(header)
225+
headerdb.persist_header_chain(headers)
217226

218227
# can we get the genesis header by hash
219228
actual = headerdb.get_block_header_by_hash(genesis_header.hash)
@@ -229,8 +238,7 @@ def test_headerdb_canonical_header_retrieval_by_number(headerdb, genesis_header)
229238

230239
headers = mk_header_chain(genesis_header, length=10)
231240

232-
for header in headers:
233-
headerdb.persist_header(header)
241+
headerdb.persist_header_chain(headers)
234242

235243
# can we get the genesis header by hash
236244
actual = headerdb.get_canonical_block_header_by_number(genesis_header.block_number)

tests/trinity/core/integration_test_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class FakeAsyncHeaderDB(AsyncHeaderDB):
5252
coro_get_score = async_passthrough('get_score')
5353
coro_header_exists = async_passthrough('header_exists')
5454
coro_persist_header = async_passthrough('persist_header')
55+
coro_persist_header_chain = async_passthrough('persist_header_chain')
5556

5657

5758
class FakeAsyncChainDB(FakeAsyncHeaderDB, AsyncChainDB):

trinity/db/header.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from multiprocessing.managers import ( # type: ignore
55
BaseProxy,
66
)
7-
from typing import Tuple
7+
from typing import (
8+
Iterable,
9+
Tuple,
10+
)
811

912
from eth_typing import (
1013
Hash32,
@@ -58,6 +61,11 @@ async def coro_header_exists(self, block_hash: Hash32) -> bool:
5861
async def coro_persist_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
5962
raise NotImplementedError("ChainDB classes must implement this method")
6063

64+
@abstractmethod
65+
async def coro_persist_header_chain(self,
66+
headers: Iterable[BlockHeader]) -> Tuple[BlockHeader, ...]:
67+
raise NotImplementedError("ChainDB classes must implement this method")
68+
6169

6270
class AsyncHeaderDB(HeaderDB, BaseAsyncHeaderDB):
6371
async def coro_get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32:
@@ -81,6 +89,10 @@ async def coro_header_exists(self, block_hash: Hash32) -> bool:
8189
async def coro_persist_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]:
8290
raise NotImplementedError("ChainDB classes must implement this method")
8391

92+
async def coro_persist_header_chain(self,
93+
headers: Iterable[BlockHeader]) -> Tuple[BlockHeader, ...]:
94+
raise NotImplementedError("ChainDB classes must implement this method")
95+
8496

8597
class AsyncHeaderDBProxy(BaseProxy, BaseAsyncHeaderDB, BaseHeaderDB):
8698
coro_get_block_header_by_hash = async_method('get_block_header_by_hash')
@@ -91,6 +103,7 @@ class AsyncHeaderDBProxy(BaseProxy, BaseAsyncHeaderDB, BaseHeaderDB):
91103
coro_header_exists = async_method('header_exists')
92104
coro_get_canonical_block_hash = async_method('get_canonical_block_hash')
93105
coro_persist_header = async_method('persist_header')
106+
coro_persist_header_chain = async_method('persist_header_chain')
94107

95108
get_block_header_by_hash = sync_method('get_block_header_by_hash')
96109
get_canonical_block_hash = sync_method('get_canonical_block_hash')
@@ -100,3 +113,4 @@ class AsyncHeaderDBProxy(BaseProxy, BaseAsyncHeaderDB, BaseHeaderDB):
100113
header_exists = sync_method('header_exists')
101114
get_canonical_block_hash = sync_method('get_canonical_block_hash')
102115
persist_header = sync_method('persist_header')
116+
persist_header_chain = sync_method('persist_header_chain')

trinity/sync/light/chain.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ async def _persist_headers(self) -> None:
6666
batch_id, headers = await self.wait(self.header_queue.get())
6767

6868
timer = Timer()
69-
for header in headers:
70-
await self.wait(self.db.coro_persist_header(header))
69+
await self.wait(self.db.coro_persist_header_chain(headers))
7170

7271
head = await self.wait(self.db.coro_get_canonical_head())
7372
self.logger.info(

0 commit comments

Comments
 (0)