Skip to content

Commit cb8242b

Browse files
committed
Add get_merkle_root
1 parent 000333a commit cb8242b

File tree

3 files changed

+107
-51
lines changed

3 files changed

+107
-51
lines changed

eth/_utils/merkle.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import (
1010
cast,
1111
Hashable,
12+
Iterable,
1213
NewType,
1314
Sequence,
1415
)
@@ -20,8 +21,8 @@
2021
reduce,
2122
take,
2223
)
23-
from eth_hash.auto import (
24-
keccak,
24+
from eth.beacon._utils.hash import (
25+
hash_eth2,
2526
)
2627
from eth_typing import (
2728
Hash32,
@@ -36,17 +37,23 @@
3637

3738

3839
def get_root(tree: MerkleTree) -> Hash32:
39-
"""Get the root hash of a Merkle tree."""
40+
"""
41+
Get the root hash of a Merkle tree.
42+
"""
4043
return tree[0][0]
4144

4245

43-
def get_branch_indices(node_index: int, depth: int) -> Sequence[int]:
44-
"""Get the indices of all ancestors up until the root for a node with a given depth."""
46+
def get_branch_indices(node_index: int, depth: int) -> Iterable[int]:
47+
"""
48+
Get the indices of all ancestors up until the root for a node with a given depth.
49+
"""
4550
return tuple(take(depth, iterate(lambda index: index // 2, node_index)))
4651

4752

48-
def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
49-
"""Read off the Merkle proof for an item from a Merkle tree."""
53+
def get_merkle_proof(tree: MerkleTree, item_index: int) -> Iterable[Hash32]:
54+
"""
55+
Read off the Merkle proof for an item from a Merkle tree.
56+
"""
5057
if item_index < 0 or item_index >= len(tree[-1]):
5158
raise ValidationError("Item index out of range")
5259

@@ -64,16 +71,20 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
6471

6572

6673
def _calc_parent_hash(left_node: Hash32, right_node: Hash32) -> Hash32:
67-
"""Calculate the parent hash of a node and its sibling."""
68-
return keccak(left_node + right_node)
74+
"""
75+
Calculate the parent hash of a node and its sibling.
76+
"""
77+
return hash_eth2(left_node + right_node)
6978

7079

7180
def verify_merkle_proof(root: Hash32,
7281
item: Hashable,
7382
item_index: int,
7483
proof: MerkleProof) -> bool:
75-
"""Verify a Merkle proof against a root hash."""
76-
leaf = keccak(item)
84+
"""
85+
Verify a Merkle proof against a root hash.
86+
"""
87+
leaf = hash_eth2(item)
7788
branch_indices = get_branch_indices(item_index, len(proof))
7889
node_orderers = [
7990
identity if branch_index % 2 == 0 else reversed
@@ -87,28 +98,56 @@ def verify_merkle_proof(root: Hash32,
8798
return proof_root == root
8899

89100

90-
def _hash_layer(layer: Sequence[Hash32]) -> Sequence[Hash32]:
91-
"""Calculate the layer on top of another one."""
92-
return tuple(_calc_parent_hash(left, right) for left, right in partition(2, layer))
101+
def _hash_layer(layer: Sequence[Hash32]) -> Iterable[Hash32]:
102+
"""
103+
Calculate the layer on top of another one.
104+
"""
105+
return tuple(
106+
_calc_parent_hash(left, right)
107+
for left, right in partition(2, layer)
108+
)
93109

94110

95111
def calc_merkle_tree(items: Sequence[Hashable]) -> MerkleTree:
96-
"""Calculate the Merkle tree corresponding to a list of items."""
97-
if len(items) == 0:
98-
raise ValidationError("No items given")
99-
n_layers = math.log2(len(items)) + 1
112+
"""
113+
Calculate the Merkle tree corresponding to a list of items.
114+
"""
115+
leaves = tuple(hash_eth2(item) for item in items)
116+
return calc_merkle_tree_from_leaves(leaves)
117+
118+
119+
def calc_merkle_root(items: Sequence[Hashable]) -> Hash32:
120+
"""
121+
Calculate the Merkle root corresponding to a list of items.
122+
"""
123+
return get_root(calc_merkle_tree(items))
124+
125+
126+
def calc_merkle_tree_from_leaves(leaves: Sequence[Hash32]) -> MerkleTree:
127+
if len(leaves) == 0:
128+
raise ValueError("No leaves given")
129+
n_layers = math.log2(len(leaves)) + 1
100130
if not n_layers.is_integer():
101-
raise ValidationError("Item number is not a power of two")
131+
raise ValueError("Number of leaves is not a power of two")
102132
n_layers = int(n_layers)
103-
104-
leaves = tuple(keccak(item) for item in items)
105-
tree = cast(MerkleTree, tuple(take(n_layers, iterate(_hash_layer, leaves)))[::-1])
133+
tree = cast(
134+
MerkleTree,
135+
tuple(
136+
take(
137+
n_layers,
138+
iterate(_hash_layer, leaves),
139+
)
140+
)[::-1]
141+
)
106142
if len(tree[0]) != 1:
107143
raise Exception("Invariant: There must only be one root")
108144

109145
return tree
110146

111147

112-
def calc_merkle_root(items: Sequence[Hashable]) -> Hash32:
113-
"""Calculate the Merkle root corresponding to a list of items."""
114-
return get_root(calc_merkle_tree(items))
148+
def get_merkle_root(leaves: Sequence[Hash32]) -> Hash32:
149+
"""
150+
Return the Merkle root of the given 32-byte hashes.
151+
Note: it has to be a full tree, i.e., `len(values)` is an exact power of 2.
152+
"""
153+
return get_root(calc_merkle_tree_from_leaves(leaves))

eth/beacon/_utils/hash.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from typing import (
2+
Hashable,
3+
)
4+
15
from eth_typing import Hash32
26
from eth_hash.auto import keccak
37

48

5-
def hash_eth2(data: bytes) -> Hash32:
9+
def hash_eth2(data: Hashable) -> Hash32:
610
"""
711
Return Keccak-256 hashed result.
812
Note: it's a placeholder and we aim to migrate to a S[T/N]ARK-friendly hash function in

tests/core/merkle-utils/test_merkle_trees.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
ValidationError,
55
)
66

7-
from eth_hash.auto import (
8-
keccak,
7+
from eth.beacon._utils.hash import (
8+
hash_eth2,
99
)
1010

1111
from eth._utils.merkle import (
1212
calc_merkle_root,
1313
calc_merkle_tree,
1414
get_root,
1515
get_merkle_proof,
16+
get_merkle_root,
1617
verify_merkle_proof,
1718
)
1819

@@ -21,41 +22,41 @@
2122
(
2223
(b"single leaf",),
2324
(
24-
(keccak(b"single leaf"),),
25+
(hash_eth2(b"single leaf"),),
2526
),
2627
),
2728
(
2829
(b"left", b"right"),
2930
(
30-
(keccak(keccak(b"left") + keccak(b"right")),),
31-
(keccak(b"left"), keccak(b"right")),
31+
(hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),),
32+
(hash_eth2(b"left"), hash_eth2(b"right")),
3233
),
3334
),
3435
(
3536
(b"1", b"2", b"3", b"4"),
3637
(
3738
(
38-
keccak(
39-
keccak(
40-
keccak(b"1") + keccak(b"2")
41-
) + keccak(
42-
keccak(b"3") + keccak(b"4")
39+
hash_eth2(
40+
hash_eth2(
41+
hash_eth2(b"1") + hash_eth2(b"2")
42+
) + hash_eth2(
43+
hash_eth2(b"3") + hash_eth2(b"4")
4344
)
4445
),
4546
),
4647
(
47-
keccak(
48-
keccak(b"1") + keccak(b"2")
48+
hash_eth2(
49+
hash_eth2(b"1") + hash_eth2(b"2")
4950
),
50-
keccak(
51-
keccak(b"3") + keccak(b"4")
51+
hash_eth2(
52+
hash_eth2(b"3") + hash_eth2(b"4")
5253
),
5354
),
5455
(
55-
keccak(b"1"),
56-
keccak(b"2"),
57-
keccak(b"3"),
58-
keccak(b"4"),
56+
hash_eth2(b"1"),
57+
hash_eth2(b"2"),
58+
hash_eth2(b"3"),
59+
hash_eth2(b"4"),
5960
),
6061
),
6162
),
@@ -69,40 +70,40 @@ def test_merkle_tree_calculation(leaves, tree):
6970

7071
@pytest.mark.parametrize("leave_number", [0, 3, 5, 6, 7, 9])
7172
def test_invalid_merkle_root_calculation(leave_number):
72-
with pytest.raises(ValidationError):
73+
with pytest.raises(ValueError):
7374
calc_merkle_root((b"",) * leave_number)
7475

7576

7677
@pytest.mark.parametrize("leaves,index,proof", [
7778
(
7879
(b"1", b"2"),
7980
0,
80-
(keccak(b"2"),),
81+
(hash_eth2(b"2"),),
8182
),
8283
(
8384
(b"1", b"2"),
8485
1,
85-
(keccak(b"1"),),
86+
(hash_eth2(b"1"),),
8687
),
8788
(
8889
(b"1", b"2", b"3", b"4"),
8990
0,
90-
(keccak(b"2"), keccak(keccak(b"3") + keccak(b"4"))),
91+
(hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
9192
),
9293
(
9394
(b"1", b"2", b"3", b"4"),
9495
1,
95-
(keccak(b"1"), keccak(keccak(b"3") + keccak(b"4"))),
96+
(hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
9697
),
9798
(
9899
(b"1", b"2", b"3", b"4"),
99100
2,
100-
(keccak(b"4"), keccak(keccak(b"1") + keccak(b"2"))),
101+
(hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
101102
),
102103
(
103104
(b"1", b"2", b"3", b"4"),
104105
3,
105-
(keccak(b"3"), keccak(keccak(b"1") + keccak(b"2"))),
106+
(hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
106107
),
107108
])
108109
def test_merkle_proofs(leaves, index, proof):
@@ -142,3 +143,15 @@ def test_proof_generation_index_validation(leaves):
142143
for invalid_index in [-1, len(leaves)]:
143144
with pytest.raises(ValidationError):
144145
get_merkle_proof(tree, invalid_index)
146+
147+
148+
def test_get_merkle_root():
149+
hash_0 = b"0" * 32
150+
leaves = (hash_0,)
151+
root = get_merkle_root(leaves)
152+
assert root == hash_0
153+
154+
hash_1 = b"1" * 32
155+
leaves = (hash_0, hash_1)
156+
root = get_merkle_root(leaves)
157+
assert root == hash_eth2(hash_0 + hash_1)

0 commit comments

Comments
 (0)