7
7
8
8
import math
9
9
from typing import (
10
- cast ,
11
- Hashable ,
10
+ Iterable ,
12
11
NewType ,
13
12
Sequence ,
13
+ Union ,
14
14
)
15
15
16
16
from cytoolz import (
20
20
reduce ,
21
21
take ,
22
22
)
23
- from eth_hash . auto import (
24
- keccak ,
23
+ from eth . beacon . _utils . hash import (
24
+ hash_eth2 ,
25
25
)
26
26
from eth_typing import (
27
27
Hash32 ,
36
36
37
37
38
38
def get_root (tree : MerkleTree ) -> Hash32 :
39
- """Get the root hash of a Merkle tree."""
39
+ """
40
+ Get the root hash of a Merkle tree.
41
+ """
40
42
return tree [0 ][0 ]
41
43
42
44
43
- def get_branch_indices (node_index : int , depth : int ) -> Sequence [int ]:
44
- """Get the indices of all ancestors up until the root for a node with a given depth."""
45
+ def get_branch_indices (node_index : int , depth : int ) -> Iterable [int ]:
46
+ """
47
+ Get the indices of all ancestors up until the root for a node with a given depth.
48
+ """
45
49
return tuple (take (depth , iterate (lambda index : index // 2 , node_index )))
46
50
47
51
48
- def get_merkle_proof (tree : MerkleTree , item_index : int ) -> Sequence [Hash32 ]:
49
- """Read off the Merkle proof for an item from a Merkle tree."""
52
+ def get_merkle_proof (tree : MerkleTree , item_index : int ) -> Iterable [Hash32 ]:
53
+ """
54
+ Read off the Merkle proof for an item from a Merkle tree.
55
+ """
50
56
if item_index < 0 or item_index >= len (tree [- 1 ]):
51
57
raise ValidationError ("Item index out of range" )
52
58
@@ -64,16 +70,20 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
64
70
65
71
66
72
def _calc_parent_hash (left_node : Hash32 , right_node : Hash32 ) -> Hash32 :
67
- """Calculate the parent hash of a node and its sibling."""
68
- return keccak (left_node + right_node )
73
+ """
74
+ Calculate the parent hash of a node and its sibling.
75
+ """
76
+ return hash_eth2 (left_node + right_node )
69
77
70
78
71
79
def verify_merkle_proof (root : Hash32 ,
72
- item : Hashable ,
80
+ item : Union [ bytes , bytearray ] ,
73
81
item_index : int ,
74
82
proof : MerkleProof ) -> bool :
75
- """Verify a Merkle proof against a root hash."""
76
- leaf = keccak (item )
83
+ """
84
+ Verify a Merkle proof against a root hash.
85
+ """
86
+ leaf = hash_eth2 (item )
77
87
branch_indices = get_branch_indices (item_index , len (proof ))
78
88
node_orderers = [
79
89
identity if branch_index % 2 == 0 else reversed
@@ -87,28 +97,51 @@ def verify_merkle_proof(root: Hash32,
87
97
return proof_root == root
88
98
89
99
90
- def _hash_layer (layer : Sequence [Hash32 ]) -> Sequence [Hash32 ]:
91
- """Calculate the layer on top of another one."""
92
- return tuple (_calc_parent_hash (left , right ) for left , right in partition (2 , layer ))
100
+ def _hash_layer (layer : Sequence [Hash32 ]) -> Iterable [Hash32 ]:
101
+ """
102
+ Calculate the layer on top of another one.
103
+ """
104
+ return tuple (
105
+ _calc_parent_hash (left , right )
106
+ for left , right in partition (2 , layer )
107
+ )
108
+
93
109
110
+ def calc_merkle_tree (items : Sequence [Union [bytes , bytearray ]]) -> MerkleTree :
111
+ """
112
+ Calculate the Merkle tree corresponding to a list of items.
113
+ """
114
+ leaves = tuple (hash_eth2 (item ) for item in items )
115
+ return calc_merkle_tree_from_leaves (leaves )
94
116
95
- def calc_merkle_tree (items : Sequence [Hashable ]) -> MerkleTree :
96
- """Calculate the Merkle tree corresponding to a list of items."""
97
- if len (items ) == 0 :
98
- raise ValidationError ("No items given" )
99
- n_layers = math .log2 (len (items )) + 1
117
+
118
+ def get_merkle_root_from_items (items : Sequence [Union [bytes , bytearray ]]) -> Hash32 :
119
+ """
120
+ Calculate the Merkle root corresponding to a list of items.
121
+ """
122
+ return get_root (calc_merkle_tree (items ))
123
+
124
+
125
+ def calc_merkle_tree_from_leaves (leaves : Sequence [Hash32 ]) -> MerkleTree :
126
+ if len (leaves ) == 0 :
127
+ raise ValueError ("No leaves given" )
128
+ n_layers = math .log2 (len (leaves )) + 1
100
129
if not n_layers .is_integer ():
101
- raise ValidationError ( "Item number is not a power of two" )
130
+ raise ValueError ( "Number of leaves is not a power of two" )
102
131
n_layers = int (n_layers )
103
132
104
- leaves = tuple (keccak (item ) for item in items )
105
- tree = cast (MerkleTree , tuple (take (n_layers , iterate (_hash_layer , leaves )))[::- 1 ])
133
+ reversed_tree = tuple (take (n_layers , iterate (_hash_layer , leaves )))
134
+ tree = MerkleTree (tuple (reversed (reversed_tree )))
135
+
106
136
if len (tree [0 ]) != 1 :
107
137
raise Exception ("Invariant: There must only be one root" )
108
138
109
139
return tree
110
140
111
141
112
- def calc_merkle_root (items : Sequence [Hashable ]) -> Hash32 :
113
- """Calculate the Merkle root corresponding to a list of items."""
114
- return get_root (calc_merkle_tree (items ))
142
+ def get_merkle_root (leaves : Sequence [Hash32 ]) -> Hash32 :
143
+ """
144
+ Return the Merkle root of the given 32-byte hashes.
145
+ Note: it has to be a full tree, i.e., `len(values)` is an exact power of 2.
146
+ """
147
+ return get_root (calc_merkle_tree_from_leaves (leaves ))
0 commit comments