Skip to content

Commit ee8f72d

Browse files
authored
Merge pull request #1689 from hwwhww/get_merkle_root
Add `get_merkle_root`
2 parents 000333a + 88445f1 commit ee8f72d

File tree

4 files changed

+109
-59
lines changed

4 files changed

+109
-59
lines changed

eth/_utils/blobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
zpad_right,
1818
)
1919
from eth._utils.merkle import (
20-
calc_merkle_root,
20+
get_merkle_root_from_items,
2121
)
2222

2323
from eth.constants import (
@@ -44,7 +44,7 @@ def iterate_chunks(collation_body: bytes) -> Iterator[Hash32]:
4444
def calc_chunk_root(collation_body: bytes) -> Hash32:
4545
check_body_size(collation_body)
4646
chunks = list(iterate_chunks(collation_body))
47-
return calc_merkle_root(chunks)
47+
return get_merkle_root_from_items(chunks)
4848

4949

5050
def check_body_size(body: bytes) -> bytes:

eth/_utils/merkle.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
import math
99
from typing import (
10-
cast,
11-
Hashable,
10+
Iterable,
1211
NewType,
1312
Sequence,
13+
Union,
1414
)
1515

1616
from cytoolz import (
@@ -20,8 +20,8 @@
2020
reduce,
2121
take,
2222
)
23-
from eth_hash.auto import (
24-
keccak,
23+
from eth.beacon._utils.hash import (
24+
hash_eth2,
2525
)
2626
from eth_typing import (
2727
Hash32,
@@ -36,17 +36,23 @@
3636

3737

3838
def get_root(tree: MerkleTree) -> Hash32:
39-
"""Get the root hash of a Merkle tree."""
39+
"""
40+
Get the root hash of a Merkle tree.
41+
"""
4042
return tree[0][0]
4143

4244

43-
def get_branch_indices(node_index: int, depth: int) -> Sequence[int]:
44-
"""Get the indices of all ancestors up until the root for a node with a given depth."""
45+
def get_branch_indices(node_index: int, depth: int) -> Iterable[int]:
46+
"""
47+
Get the indices of all ancestors up until the root for a node with a given depth.
48+
"""
4549
return tuple(take(depth, iterate(lambda index: index // 2, node_index)))
4650

4751

48-
def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
49-
"""Read off the Merkle proof for an item from a Merkle tree."""
52+
def get_merkle_proof(tree: MerkleTree, item_index: int) -> Iterable[Hash32]:
53+
"""
54+
Read off the Merkle proof for an item from a Merkle tree.
55+
"""
5056
if item_index < 0 or item_index >= len(tree[-1]):
5157
raise ValidationError("Item index out of range")
5258

@@ -64,16 +70,20 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
6470

6571

6672
def _calc_parent_hash(left_node: Hash32, right_node: Hash32) -> Hash32:
67-
"""Calculate the parent hash of a node and its sibling."""
68-
return keccak(left_node + right_node)
73+
"""
74+
Calculate the parent hash of a node and its sibling.
75+
"""
76+
return hash_eth2(left_node + right_node)
6977

7078

7179
def verify_merkle_proof(root: Hash32,
72-
item: Hashable,
80+
item: Union[bytes, bytearray],
7381
item_index: int,
7482
proof: MerkleProof) -> bool:
75-
"""Verify a Merkle proof against a root hash."""
76-
leaf = keccak(item)
83+
"""
84+
Verify a Merkle proof against a root hash.
85+
"""
86+
leaf = hash_eth2(item)
7787
branch_indices = get_branch_indices(item_index, len(proof))
7888
node_orderers = [
7989
identity if branch_index % 2 == 0 else reversed
@@ -87,28 +97,51 @@ def verify_merkle_proof(root: Hash32,
8797
return proof_root == root
8898

8999

90-
def _hash_layer(layer: Sequence[Hash32]) -> Sequence[Hash32]:
91-
"""Calculate the layer on top of another one."""
92-
return tuple(_calc_parent_hash(left, right) for left, right in partition(2, layer))
100+
def _hash_layer(layer: Sequence[Hash32]) -> Iterable[Hash32]:
101+
"""
102+
Calculate the layer on top of another one.
103+
"""
104+
return tuple(
105+
_calc_parent_hash(left, right)
106+
for left, right in partition(2, layer)
107+
)
108+
93109

110+
def calc_merkle_tree(items: Sequence[Union[bytes, bytearray]]) -> MerkleTree:
111+
"""
112+
Calculate the Merkle tree corresponding to a list of items.
113+
"""
114+
leaves = tuple(hash_eth2(item) for item in items)
115+
return calc_merkle_tree_from_leaves(leaves)
94116

95-
def calc_merkle_tree(items: Sequence[Hashable]) -> MerkleTree:
96-
"""Calculate the Merkle tree corresponding to a list of items."""
97-
if len(items) == 0:
98-
raise ValidationError("No items given")
99-
n_layers = math.log2(len(items)) + 1
117+
118+
def get_merkle_root_from_items(items: Sequence[Union[bytes, bytearray]]) -> Hash32:
119+
"""
120+
Calculate the Merkle root corresponding to a list of items.
121+
"""
122+
return get_root(calc_merkle_tree(items))
123+
124+
125+
def calc_merkle_tree_from_leaves(leaves: Sequence[Hash32]) -> MerkleTree:
126+
if len(leaves) == 0:
127+
raise ValueError("No leaves given")
128+
n_layers = math.log2(len(leaves)) + 1
100129
if not n_layers.is_integer():
101-
raise ValidationError("Item number is not a power of two")
130+
raise ValueError("Number of leaves is not a power of two")
102131
n_layers = int(n_layers)
103132

104-
leaves = tuple(keccak(item) for item in items)
105-
tree = cast(MerkleTree, tuple(take(n_layers, iterate(_hash_layer, leaves)))[::-1])
133+
reversed_tree = tuple(take(n_layers, iterate(_hash_layer, leaves)))
134+
tree = MerkleTree(tuple(reversed(reversed_tree)))
135+
106136
if len(tree[0]) != 1:
107137
raise Exception("Invariant: There must only be one root")
108138

109139
return tree
110140

111141

112-
def calc_merkle_root(items: Sequence[Hashable]) -> Hash32:
113-
"""Calculate the Merkle root corresponding to a list of items."""
114-
return get_root(calc_merkle_tree(items))
142+
def get_merkle_root(leaves: Sequence[Hash32]) -> Hash32:
143+
"""
144+
Return the Merkle root of the given 32-byte hashes.
145+
Note: it has to be a full tree, i.e., `len(leaves)` is an exact power of 2.
146+
"""
147+
return get_root(calc_merkle_tree_from_leaves(leaves))

eth/beacon/_utils/hash.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from typing import (
2+
Union,
3+
)
4+
15
from eth_typing import Hash32
26
from eth_hash.auto import keccak
37

48

5-
def hash_eth2(data: bytes) -> Hash32:
9+
def hash_eth2(data: Union[bytes, bytearray]) -> Hash32:
610
"""
711
Return Keccak-256 hashed result.
812
Note: it's a placeholder and we aim to migrate to a S[T/N]ARK-friendly hash function in

tests/core/merkle-utils/test_merkle_trees.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
ValidationError,
55
)
66

7-
from eth_hash.auto import (
8-
keccak,
7+
from eth.beacon._utils.hash import (
8+
hash_eth2,
99
)
1010

1111
from eth._utils.merkle import (
12-
calc_merkle_root,
12+
get_merkle_root_from_items,
1313
calc_merkle_tree,
1414
get_root,
1515
get_merkle_proof,
16+
get_merkle_root,
1617
verify_merkle_proof,
1718
)
1819

@@ -21,41 +22,41 @@
2122
(
2223
(b"single leaf",),
2324
(
24-
(keccak(b"single leaf"),),
25+
(hash_eth2(b"single leaf"),),
2526
),
2627
),
2728
(
2829
(b"left", b"right"),
2930
(
30-
(keccak(keccak(b"left") + keccak(b"right")),),
31-
(keccak(b"left"), keccak(b"right")),
31+
(hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),),
32+
(hash_eth2(b"left"), hash_eth2(b"right")),
3233
),
3334
),
3435
(
3536
(b"1", b"2", b"3", b"4"),
3637
(
3738
(
38-
keccak(
39-
keccak(
40-
keccak(b"1") + keccak(b"2")
41-
) + keccak(
42-
keccak(b"3") + keccak(b"4")
39+
hash_eth2(
40+
hash_eth2(
41+
hash_eth2(b"1") + hash_eth2(b"2")
42+
) + hash_eth2(
43+
hash_eth2(b"3") + hash_eth2(b"4")
4344
)
4445
),
4546
),
4647
(
47-
keccak(
48-
keccak(b"1") + keccak(b"2")
48+
hash_eth2(
49+
hash_eth2(b"1") + hash_eth2(b"2")
4950
),
50-
keccak(
51-
keccak(b"3") + keccak(b"4")
51+
hash_eth2(
52+
hash_eth2(b"3") + hash_eth2(b"4")
5253
),
5354
),
5455
(
55-
keccak(b"1"),
56-
keccak(b"2"),
57-
keccak(b"3"),
58-
keccak(b"4"),
56+
hash_eth2(b"1"),
57+
hash_eth2(b"2"),
58+
hash_eth2(b"3"),
59+
hash_eth2(b"4"),
5960
),
6061
),
6162
),
@@ -64,45 +65,45 @@ def test_merkle_tree_calculation(leaves, tree):
6465
calculated_tree = calc_merkle_tree(leaves)
6566
assert calculated_tree == tree
6667
assert get_root(tree) == tree[0][0]
67-
assert calc_merkle_root(leaves) == get_root(tree)
68+
assert get_merkle_root_from_items(leaves) == get_root(tree)
6869

6970

7071
@pytest.mark.parametrize("leave_number", [0, 3, 5, 6, 7, 9])
7172
def test_invalid_merkle_root_calculation(leave_number):
72-
with pytest.raises(ValidationError):
73-
calc_merkle_root((b"",) * leave_number)
73+
with pytest.raises(ValueError):
74+
get_merkle_root_from_items((b"",) * leave_number)
7475

7576

7677
@pytest.mark.parametrize("leaves,index,proof", [
7778
(
7879
(b"1", b"2"),
7980
0,
80-
(keccak(b"2"),),
81+
(hash_eth2(b"2"),),
8182
),
8283
(
8384
(b"1", b"2"),
8485
1,
85-
(keccak(b"1"),),
86+
(hash_eth2(b"1"),),
8687
),
8788
(
8889
(b"1", b"2", b"3", b"4"),
8990
0,
90-
(keccak(b"2"), keccak(keccak(b"3") + keccak(b"4"))),
91+
(hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
9192
),
9293
(
9394
(b"1", b"2", b"3", b"4"),
9495
1,
95-
(keccak(b"1"), keccak(keccak(b"3") + keccak(b"4"))),
96+
(hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
9697
),
9798
(
9899
(b"1", b"2", b"3", b"4"),
99100
2,
100-
(keccak(b"4"), keccak(keccak(b"1") + keccak(b"2"))),
101+
(hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
101102
),
102103
(
103104
(b"1", b"2", b"3", b"4"),
104105
3,
105-
(keccak(b"3"), keccak(keccak(b"1") + keccak(b"2"))),
106+
(hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
106107
),
107108
])
108109
def test_merkle_proofs(leaves, index, proof):
@@ -142,3 +143,15 @@ def test_proof_generation_index_validation(leaves):
142143
for invalid_index in [-1, len(leaves)]:
143144
with pytest.raises(ValidationError):
144145
get_merkle_proof(tree, invalid_index)
146+
147+
148+
def test_get_merkle_root():
149+
hash_0 = b"0" * 32
150+
leaves = (hash_0,)
151+
root = get_merkle_root(leaves)
152+
assert root == hash_0
153+
154+
hash_1 = b"1" * 32
155+
leaves = (hash_0, hash_1)
156+
root = get_merkle_root(leaves)
157+
assert root == hash_eth2(hash_0 + hash_1)

0 commit comments

Comments
 (0)