Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions trie/bintrie/binary_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,18 @@ func SerializeNode(node BinaryNode) []byte {

var invalidSerializedLength = errors.New("invalid serialized node length")

// DeserializeNode deserializes a binary trie node from a byte slice.
// DeserializeNode deserializes a binary trie node from a byte slice. The
// hash will be recomputed from the deserialized data.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return deserializeNode(serialized, depth, common.Hash{}, true)
}

// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
return deserializeNode(serialized, depth, hn, false)
}

func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
Expand All @@ -102,9 +112,11 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
mustRecompute: mustRecompute,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
Expand All @@ -124,9 +136,11 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
}
}
return &StemNode{
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
mustRecompute: mustRecompute,
}, nil
default:
return nil, errors.New("invalid node type")
Expand Down
14 changes: 8 additions & 6 deletions trie/bintrie/empty.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (Bi
var values [256][]byte
values[key[31]] = value
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
mustRecompute: true,
}, nil
}

Expand All @@ -53,9 +54,10 @@ func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {

func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
mustRecompute: true,
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion trie/bintrie/hashed_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No
}

// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNode(data, depth)
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
Expand Down
29 changes: 21 additions & 8 deletions trie/bintrie/internal_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
type InternalNode struct {
left, right BinaryNode
depth int

mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}

// GetValuesAtStem retrieves the group of values located at the given stem key.
Expand All @@ -59,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
Expand All @@ -77,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
Expand Down Expand Up @@ -108,14 +111,20 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
// Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
mustRecompute: bt.mustRecompute,
hash: bt.hash,
}
}

// Hash returns the hash of the node.
func (bt *InternalNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}

h := sha256.New()
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
Expand All @@ -127,7 +136,9 @@ func (bt *InternalNode) Hash() common.Hash {
} else {
h.Write(zero[:])
}
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}

// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
Expand All @@ -149,14 +160,15 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}

bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}

Expand All @@ -173,14 +185,15 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.right = node
}

bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}

Expand Down
8 changes: 5 additions & 3 deletions trie/bintrie/internal_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,18 @@ func TestInternalNodeHash(t *testing.T) {

// Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
}

// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
if hashWithNil == (common.Hash{}) {
Expand Down
2 changes: 1 addition & 1 deletion trie/bintrie/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
if err != nil {
panic(err)
}
it.current, err = DeserializeNode(data, len(it.stack)-1)
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}
Expand Down
39 changes: 27 additions & 12 deletions trie/bintrie/stem_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type StemNode struct {
Stem []byte // Stem path to get to StemNodeWidth values
Values [][]byte // All values, indexed by the last byte of the key.
depth int // Depth of the node

mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}

// Get retrieves the value for the given key.
Expand All @@ -43,7 +46,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1

n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
Expand All @@ -68,9 +71,10 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
var values [StemNodeWidth][]byte
values[key[StemSize]] = value
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
mustRecompute: true,
}
}
return n, nil
Expand All @@ -79,6 +83,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
return bt, errors.New("invalid insertion: value length")
}
bt.Values[key[StemSize]] = value
bt.mustRecompute = true
return bt, nil
}

Expand All @@ -89,9 +94,11 @@ func (bt *StemNode) Copy() BinaryNode {
values[i] = slices.Clone(v)
}
return &StemNode{
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
hash: bt.hash,
mustRecompute: bt.mustRecompute,
}
}

Expand All @@ -102,6 +109,10 @@ func (bt *StemNode) GetHeight() int {

// Hash returns the hash of the node.
func (bt *StemNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}

var data [StemNodeWidth]common.Hash
for i, v := range bt.Values {
if v != nil {
Expand Down Expand Up @@ -130,7 +141,9 @@ func (bt *StemNode) Hash() common.Hash {
h.Write(bt.Stem)
h.Write([]byte{0})
h.Write(data[0][:])
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}

// CollectNodes collects all child nodes at a given path, and flushes it
Expand All @@ -154,7 +167,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1

n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
Expand All @@ -177,9 +190,10 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
*other = Empty{}
} else {
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
mustRecompute: true,
}
}
return n, nil
Expand All @@ -189,6 +203,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
for i, v := range values {
if v != nil {
bt.Values[i] = v
bt.mustRecompute = true
}
}
return bt, nil
Expand Down
1 change: 1 addition & 0 deletions trie/bintrie/stem_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ func TestStemNodeHash(t *testing.T) {

// Changing a value should change the hash
node.Values[1] = common.HexToHash("0x0202").Bytes()
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")
Expand Down
2 changes: 1 addition & 1 deletion trie/bintrie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
node, err := DeserializeNode(blob, 0)
node, err := DeserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}
Expand Down