diff --git a/core/types/hashes.go b/core/types/hashes.go index 05cfaeed748..22f1f946dc3 100644 --- a/core/types/hashes.go +++ b/core/types/hashes.go @@ -45,4 +45,7 @@ var ( // EmptyVerkleHash is the known hash of an empty verkle trie. EmptyVerkleHash = common.Hash{} + + // EmptyBinaryHash is the known hash of an empty binary trie. + EmptyBinaryHash = common.Hash{} ) diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go new file mode 100644 index 00000000000..1c003a6c8fd --- /dev/null +++ b/trie/bintrie/binary_node.go @@ -0,0 +1,133 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" +) + +type ( + NodeFlushFn func([]byte, BinaryNode) + NodeResolverFn func([]byte, common.Hash) ([]byte, error) +) + +// zero is the zero value for a 32-byte array. +var zero [32]byte + +const ( + NodeWidth = 256 // Number of child per leaf node + StemSize = 31 // Number of bytes to travel before reaching a group of leaves +) + +const ( + nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values + nodeTypeInternal +) + +// BinaryNode is an interface for a binary trie node. +type BinaryNode interface { + Get([]byte, NodeResolverFn) ([]byte, error) + Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error) + Copy() BinaryNode + Hash() common.Hash + GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error) + InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error) + CollectNodes([]byte, NodeFlushFn) error + + toDot(parent, path string) string + GetHeight() int +} + +// SerializeNode serializes a binary trie node into a byte slice. +func SerializeNode(node BinaryNode) []byte { + switch n := (node).(type) { + case *InternalNode: + var serialized [65]byte + serialized[0] = nodeTypeInternal + copy(serialized[1:33], n.left.Hash().Bytes()) + copy(serialized[33:65], n.right.Hash().Bytes()) + return serialized[:] + case *StemNode: + var serialized [32 + 32 + 256*32]byte + serialized[0] = nodeTypeStem + copy(serialized[1:32], node.(*StemNode).Stem) + bitmap := serialized[32:64] + offset := 64 + for i, v := range node.(*StemNode).Values { + if v != nil { + bitmap[i/8] |= 1 << (7 - (i % 8)) + copy(serialized[offset:offset+32], v) + offset += 32 + } + } + return serialized[:] + default: + panic("invalid node type") + } +} + +var invalidSerializedLength = errors.New("invalid serialized node length") + +// DeserializeNode deserializes a binary trie node from a byte slice. +func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { + if len(serialized) == 0 { + return Empty{}, nil + } + + switch serialized[0] { + case nodeTypeInternal: + if len(serialized) != 65 { + return nil, invalidSerializedLength + } + return &InternalNode{ + depth: depth, + left: HashedNode(common.BytesToHash(serialized[1:33])), + right: HashedNode(common.BytesToHash(serialized[33:65])), + }, nil + case nodeTypeStem: + if len(serialized) < 64 { + return nil, invalidSerializedLength + } + var values [256][]byte + bitmap := serialized[32:64] + offset := 64 + + for i := range 256 { + if bitmap[i/8]>>(7-(i%8))&1 == 1 { + if len(serialized) < offset+32 { + return nil, invalidSerializedLength + } + values[i] = serialized[offset : offset+32] + offset += 32 + } + } + return &StemNode{ + Stem: serialized[1:32], + Values: values[:], + depth: depth, + }, nil + default: + return nil, errors.New("invalid node type") + } +} + +// ToDot converts the binary trie to a DOT language representation. Useful for debugging. +func ToDot(root BinaryNode) string { + return root.toDot("", "") +} diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go new file mode 100644 index 00000000000..b21daaab697 --- /dev/null +++ b/trie/bintrie/binary_node_test.go @@ -0,0 +1,252 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode +func TestSerializeDeserializeInternalNode(t *testing.T) { + // Create an internal node with two hashed children + leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") + + node := &InternalNode{ + depth: 5, + left: HashedNode(leftHash), + right: HashedNode(rightHash), + } + + // Serialize the node + serialized := SerializeNode(node) + + // Check the serialized format + if serialized[0] != nodeTypeInternal { + t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) + } + + if len(serialized) != 65 { + t.Errorf("Expected serialized length to be 65, got %d", len(serialized)) + } + + // Deserialize the node + deserialized, err := DeserializeNode(serialized, 5) + if err != nil { + t.Fatalf("Failed to deserialize node: %v", err) + } + + // Check that it's an internal node + internalNode, ok := deserialized.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", deserialized) + } + + // Check the depth + if internalNode.depth != 5 { + t.Errorf("Expected depth 5, got %d", internalNode.depth) + } + + // Check the left and right hashes + if internalNode.left.Hash() != leftHash { + t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) + } + + if internalNode.right.Hash() != rightHash { + t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) + } +} + +// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode +func TestSerializeDeserializeStemNode(t *testing.T) { + // Create a stem node with some values + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + // Add some values at different indices + values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes() + values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes() + values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 10, + } + + // Serialize the node + serialized := SerializeNode(node) + + // Check the serialized format + if serialized[0] != nodeTypeStem { + t.Errorf("Expected type byte to be %d, got %d", nodeTypeStem, serialized[0]) + } + + // Check the stem is correctly serialized + if !bytes.Equal(serialized[1:32], stem) { + t.Errorf("Stem mismatch in serialized data") + } + + // Deserialize the node + deserialized, err := DeserializeNode(serialized, 10) + if err != nil { + t.Fatalf("Failed to deserialize node: %v", err) + } + + // Check that it's a stem node + stemNode, ok := deserialized.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", deserialized) + } + + // Check the stem + if !bytes.Equal(stemNode.Stem, stem) { + t.Errorf("Stem mismatch after deserialization") + } + + // Check the values + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Errorf("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], values[10]) { + t.Errorf("Value at index 10 mismatch") + } + if !bytes.Equal(stemNode.Values[255], values[255]) { + t.Errorf("Value at index 255 mismatch") + } + + // Check that other values are nil + for i := range NodeWidth { + if i == 0 || i == 10 || i == 255 { + continue + } + if stemNode.Values[i] != nil { + t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) + } + } +} + +// TestDeserializeEmptyNode tests deserialization of empty node +func TestDeserializeEmptyNode(t *testing.T) { + // Empty byte slice should deserialize to Empty node + deserialized, err := DeserializeNode([]byte{}, 0) + if err != nil { + t.Fatalf("Failed to deserialize empty node: %v", err) + } + + _, ok := deserialized.(Empty) + if !ok { + t.Fatalf("Expected Empty node, got %T", deserialized) + } +} + +// TestDeserializeInvalidType tests deserialization with invalid type byte +func TestDeserializeInvalidType(t *testing.T) { + // Create invalid serialized data with unknown type byte + invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid + + _, err := DeserializeNode(invalidData, 0) + if err == nil { + t.Fatal("Expected error for invalid type byte, got nil") + } +} + +// TestDeserializeInvalidLength tests deserialization with invalid data length +func TestDeserializeInvalidLength(t *testing.T) { + // InternalNode with type byte 1 but wrong length + invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node + + _, err := DeserializeNode(invalidData, 0) + if err == nil { + t.Fatal("Expected error for invalid data length, got nil") + } + + if err.Error() != "invalid serialized node length" { + t.Errorf("Expected 'invalid serialized node length' error, got: %v", err) + } +} + +// TestKeyToPath tests the keyToPath function +func TestKeyToPath(t *testing.T) { + tests := []struct { + name string + depth int + key []byte + expected []byte + wantErr bool + }{ + { + name: "depth 0", + depth: 0, + key: []byte{0x80}, // 10000000 in binary + expected: []byte{1}, + wantErr: false, + }, + { + name: "depth 7", + depth: 7, + key: []byte{0xFF}, // 11111111 in binary + expected: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + wantErr: false, + }, + { + name: "depth crossing byte boundary", + depth: 10, + key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary + expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}, + wantErr: false, + }, + { + name: "max valid depth", + depth: 31 * 8, + key: make([]byte, 32), + expected: make([]byte, 31*8+1), + wantErr: false, + }, + { + name: "depth too large", + depth: 31*8 + 1, + key: make([]byte, 32), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path, err := keyToPath(tt.depth, tt.key) + if tt.wantErr { + if err == nil { + t.Errorf("Expected error for depth %d, got nil", tt.depth) + } + return + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + if !bytes.Equal(path, tt.expected) { + t.Errorf("Path mismatch: expected %v, got %v", tt.expected, path) + } + }) + } +} diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go new file mode 100644 index 00000000000..7cfe373b35b --- /dev/null +++ b/trie/bintrie/empty.go @@ -0,0 +1,72 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "slices" + + "github.com/ethereum/go-ethereum/common" +) + +type Empty struct{} + +func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { + return nil, nil +} + +func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + var values [256][]byte + values[key[31]] = value + return &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values[:], + depth: depth, + }, nil +} + +func (e Empty) Copy() BinaryNode { + return Empty{} +} + +func (e Empty) Hash() common.Hash { + return common.Hash{} +} + +func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + var values [256][]byte + return values[:], nil +} + +func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + return &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values, + depth: depth, + }, nil +} + +func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error { + return nil +} + +func (e Empty) toDot(parent string, path string) string { + return "" +} + +func (e Empty) GetHeight() int { + return 0 +} diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go new file mode 100644 index 00000000000..574ae1830be --- /dev/null +++ b/trie/bintrie/empty_test.go @@ -0,0 +1,222 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestEmptyGet tests the Get method +func TestEmptyGet(t *testing.T) { + node := Empty{} + + key := make([]byte, 32) + value, err := node.Get(key, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if value != nil { + t.Errorf("Expected nil value from empty node, got %x", value) + } +} + +// TestEmptyInsert tests the Insert method +func TestEmptyInsert(t *testing.T) { + node := Empty{} + + key := make([]byte, 32) + key[0] = 0x12 + key[31] = 0x34 + value := common.HexToHash("0xabcd").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should create a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check the stem (first 31 bytes of key) + if !bytes.Equal(stemNode.Stem, key[:31]) { + t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem) + } + + // Check the value at the correct index (last byte of key) + if !bytes.Equal(stemNode.Values[key[31]], value) { + t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]]) + } + + // Check that other values are nil + for i := 0; i < 256; i++ { + if i != int(key[31]) && stemNode.Values[i] != nil { + t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) + } + } +} + +// TestEmptyCopy tests the Copy method +func TestEmptyCopy(t *testing.T) { + node := Empty{} + + copied := node.Copy() + copiedEmpty, ok := copied.(Empty) + if !ok { + t.Fatalf("Expected Empty, got %T", copied) + } + + // Both should be empty + if node != copiedEmpty { + // Empty is a zero-value struct, so copies should be equal + t.Errorf("Empty nodes should be equal") + } +} + +// TestEmptyHash tests the Hash method +func TestEmptyHash(t *testing.T) { + node := Empty{} + + hash := node.Hash() + + // Empty node should have zero hash + if hash != (common.Hash{}) { + t.Errorf("Expected zero hash for empty node, got %x", hash) + } +} + +// TestEmptyGetValuesAtStem tests the GetValuesAtStem method +func TestEmptyGetValuesAtStem(t *testing.T) { + node := Empty{} + + stem := make([]byte, 31) + values, err := node.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Should return an array of 256 nil values + if len(values) != 256 { + t.Errorf("Expected 256 values, got %d", len(values)) + } + + for i, v := range values { + if v != nil { + t.Errorf("Expected nil value at index %d, got %x", i, v) + } + } +} + +// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method +func TestEmptyInsertValuesAtStem(t *testing.T) { + node := Empty{} + + stem := make([]byte, 31) + stem[0] = 0x42 + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[10] = common.HexToHash("0x0202").Bytes() + values[255] = common.HexToHash("0x0303").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + // Should create a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check the stem + if !bytes.Equal(stemNode.Stem, stem) { + t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem) + } + + // Check the depth + if stemNode.depth != 5 { + t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth) + } + + // Check the values + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Error("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], values[10]) { + t.Error("Value at index 10 mismatch") + } + if !bytes.Equal(stemNode.Values[255], values[255]) { + t.Error("Value at index 255 mismatch") + } + + // Check that values is the same slice (not a copy) + if &stemNode.Values[0] != &values[0] { + t.Error("Expected values to be the same slice reference") + } +} + +// TestEmptyCollectNodes tests the CollectNodes method +func TestEmptyCollectNodes(t *testing.T) { + node := Empty{} + + var collected []BinaryNode + flushFn := func(path []byte, n BinaryNode) { + collected = append(collected, n) + } + + err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Should not collect anything for empty node + if len(collected) != 0 { + t.Errorf("Expected no collected nodes for empty, got %d", len(collected)) + } +} + +// TestEmptyToDot tests the toDot method +func TestEmptyToDot(t *testing.T) { + node := Empty{} + + dot := node.toDot("parent", "010") + + // Should return empty string for empty node + if dot != "" { + t.Errorf("Expected empty string for empty node toDot, got %s", dot) + } +} + +// TestEmptyGetHeight tests the GetHeight method +func TestEmptyGetHeight(t *testing.T) { + node := Empty{} + + height := node.GetHeight() + + // Empty node should have height 0 + if height != 0 { + t.Errorf("Expected height 0 for empty node, got %d", height) + } +} diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go new file mode 100644 index 00000000000..8f9fd66a59a --- /dev/null +++ b/trie/bintrie/hashed_node.go @@ -0,0 +1,66 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +type HashedNode common.Hash + +func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { + panic("not implemented") // TODO: Implement +} + +func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + return nil, errors.New("insert not implemented for hashed node") +} + +func (h HashedNode) Copy() BinaryNode { + nh := common.Hash(h) + return HashedNode(nh) +} + +func (h HashedNode) Hash() common.Hash { + return common.Hash(h) +} + +func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + return nil, errors.New("attempted to get values from an unresolved node") +} + +func (h HashedNode) InsertValuesAtStem(key []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + return nil, errors.New("insertValuesAtStem not implemented for hashed node") +} + +func (h HashedNode) toDot(parent string, path string) string { + me := fmt.Sprintf("hash%s", path) + ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + return ret +} + +func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error { + return errors.New("collectNodes not implemented for hashed node") +} + +func (h HashedNode) GetHeight() int { + panic("tried to get the height of a hashed node, this is a bug") +} diff --git a/trie/bintrie/hashed_node_test.go b/trie/bintrie/hashed_node_test.go new file mode 100644 index 00000000000..0c19ae0c57d --- /dev/null +++ b/trie/bintrie/hashed_node_test.go @@ -0,0 +1,128 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestHashedNodeHash tests the Hash method +func TestHashedNodeHash(t *testing.T) { + hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + node := HashedNode(hash) + + // Hash should return the stored hash + if node.Hash() != hash { + t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash()) + } +} + +// TestHashedNodeCopy tests the Copy method +func TestHashedNodeCopy(t *testing.T) { + hash := common.HexToHash("0xabcdef") + node := HashedNode(hash) + + copied := node.Copy() + copiedHash, ok := copied.(HashedNode) + if !ok { + t.Fatalf("Expected HashedNode, got %T", copied) + } + + // Hash should be the same + if common.Hash(copiedHash) != hash { + t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash) + } + + // But should be a different object + if &node == &copiedHash { + t.Error("Copy returned same object reference") + } +} + +// TestHashedNodeInsert tests that Insert returns an error +func TestHashedNodeInsert(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + key := make([]byte, 32) + value := make([]byte, 32) + + _, err := node.Insert(key, value, nil, 0) + if err == nil { + t.Fatal("Expected error for Insert on HashedNode") + } + + if err.Error() != "insert not implemented for hashed node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error +func TestHashedNodeGetValuesAtStem(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + stem := make([]byte, 31) + _, err := node.GetValuesAtStem(stem, nil) + if err == nil { + t.Fatal("Expected error for GetValuesAtStem on HashedNode") + } + + if err.Error() != "attempted to get values from an unresolved node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error +func TestHashedNodeInsertValuesAtStem(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + stem := make([]byte, 31) + values := make([][]byte, 256) + + _, err := node.InsertValuesAtStem(stem, values, nil, 0) + if err == nil { + t.Fatal("Expected error for InsertValuesAtStem on HashedNode") + } + + if err.Error() != "insertValuesAtStem not implemented for hashed node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeToDot tests the toDot method for visualization +func TestHashedNodeToDot(t *testing.T) { + hash := common.HexToHash("0x1234") + node := HashedNode(hash) + + dot := node.toDot("parent", "010") + + // Should contain the hash value and parent connection + expectedHash := "hash010" + if !contains(dot, expectedHash) { + t.Errorf("Expected dot output to contain %s", expectedHash) + } + + if !contains(dot, "parent -> hash010") { + t.Error("Expected dot output to contain parent connection") + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && s != "" && len(substr) > 0 +} diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go new file mode 100644 index 00000000000..f3ddd1aab02 --- /dev/null +++ b/trie/bintrie/internal_node.go @@ -0,0 +1,189 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "crypto/sha256" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +func keyToPath(depth int, key []byte) ([]byte, error) { + if depth > 31*8 { + return nil, errors.New("node too deep") + } + path := make([]byte, 0, depth+1) + for i := range depth + 1 { + bit := key[i/8] >> (7 - (i % 8)) & 1 + path = append(path, bit) + } + return path, nil +} + +// InternalNode is a binary trie internal node. +type InternalNode struct { + left, right BinaryNode + depth int +} + +// GetValuesAtStem retrieves the group of values located at the given stem key. +func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) { + if bt.depth > 31*8 { + return nil, errors.New("node too deep") + } + + bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + var child *BinaryNode + if bit == 0 { + child = &bt.left + } else { + child = &bt.right + } + + if hn, ok := (*child).(HashedNode); ok { + path, err := keyToPath(bt.depth, stem) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + data, err := resolver(path, common.Hash(hn)) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + node, err := DeserializeNode(data, bt.depth+1) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) + } + *child = node + } + return (*child).GetValuesAtStem(stem, resolver) +} + +// Get retrieves the value for the given key. +func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) { + values, err := bt.GetValuesAtStem(key[:31], resolver) + if err != nil { + return nil, fmt.Errorf("get error: %w", err) + } + return values[key[31]], nil +} + +// Insert inserts a new key-value pair into the trie. +func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + var values [256][]byte + values[key[31]] = value + return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth) +} + +// Copy creates a deep copy of the node. +func (bt *InternalNode) Copy() BinaryNode { + return &InternalNode{ + left: bt.left.Copy(), + right: bt.right.Copy(), + depth: bt.depth, + } +} + +// Hash returns the hash of the node. +func (bt *InternalNode) Hash() common.Hash { + h := sha256.New() + if bt.left != nil { + h.Write(bt.left.Hash().Bytes()) + } else { + h.Write(zero[:]) + } + if bt.right != nil { + h.Write(bt.right.Hash().Bytes()) + } else { + h.Write(zero[:]) + } + return common.BytesToHash(h.Sum(nil)) +} + +// InsertValuesAtStem inserts a full value group at the given stem in the internal node. +// Already-existing values will be overwritten. +func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + var ( + child *BinaryNode + err error + ) + bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + if bit == 0 { + child = &bt.left + } else { + child = &bt.right + } + *child, err = (*child).InsertValuesAtStem(stem, values, resolver, depth+1) + return bt, err +} + +// CollectNodes collects all child nodes at a given path, and flushes it +// into the provided node collector. +func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { + if bt.left != nil { + var p [256]byte + copy(p[:], path) + childpath := p[:len(path)] + childpath = append(childpath, 0) + if err := bt.left.CollectNodes(childpath, flushfn); err != nil { + return err + } + } + if bt.right != nil { + var p [256]byte + copy(p[:], path) + childpath := p[:len(path)] + childpath = append(childpath, 1) + if err := bt.right.CollectNodes(childpath, flushfn); err != nil { + return err + } + } + flushfn(path, bt) + return nil +} + +// GetHeight returns the height of the node. +func (bt *InternalNode) GetHeight() int { + var ( + leftHeight int + rightHeight int + ) + if bt.left != nil { + leftHeight = bt.left.GetHeight() + } + if bt.right != nil { + rightHeight = bt.right.GetHeight() + } + return 1 + max(leftHeight, rightHeight) +} + +func (bt *InternalNode) toDot(parent, path string) string { + me := fmt.Sprintf("internal%s", path) + ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash()) + if len(parent) > 0 { + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + } + + if bt.left != nil { + ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0))) + } + if bt.right != nil { + ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1))) + } + return ret +} diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go new file mode 100644 index 00000000000..158d8b7147d --- /dev/null +++ b/trie/bintrie/internal_node_test.go @@ -0,0 +1,458 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "errors" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestInternalNodeGet tests the Get method +func TestInternalNodeGet(t *testing.T) { + // Create a simple tree structure + leftStem := make([]byte, 31) + rightStem := make([]byte, 31) + rightStem[0] = 0x80 // First bit is 1 + + var leftValues, rightValues [256][]byte + leftValues[0] = common.HexToHash("0x0101").Bytes() + rightValues[0] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: &StemNode{ + Stem: leftStem, + Values: leftValues[:], + depth: 1, + }, + right: &StemNode{ + Stem: rightStem, + Values: rightValues[:], + depth: 1, + }, + } + + // Get value from left subtree + leftKey := make([]byte, 32) + leftKey[31] = 0 + value, err := node.Get(leftKey, nil) + if err != nil { + t.Fatalf("Failed to get left value: %v", err) + } + if !bytes.Equal(value, leftValues[0]) { + t.Errorf("Left value mismatch: expected %x, got %x", leftValues[0], value) + } + + // Get value from right subtree + rightKey := make([]byte, 32) + rightKey[0] = 0x80 + rightKey[31] = 0 + value, err = node.Get(rightKey, nil) + if err != nil { + t.Fatalf("Failed to get right value: %v", err) + } + if !bytes.Equal(value, rightValues[0]) { + t.Errorf("Right value mismatch: expected %x, got %x", rightValues[0], value) + } +} + +// TestInternalNodeGetWithResolver tests Get with HashedNode resolution +func TestInternalNodeGetWithResolver(t *testing.T) { + // Create an internal node with a hashed child + hashedChild := HashedNode(common.HexToHash("0x1234")) + + node := &InternalNode{ + depth: 0, + left: hashedChild, + right: Empty{}, + } + + // Mock resolver that returns a stem node + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + if hash == common.Hash(hashedChild) { + stem := make([]byte, 31) + var values [256][]byte + values[5] = common.HexToHash("0xabcd").Bytes() + stemNode := &StemNode{ + Stem: stem, + Values: values[:], + depth: 1, + } + return SerializeNode(stemNode), nil + } + return nil, errors.New("node not found") + } + + // Get value through the hashed node + key := make([]byte, 32) + key[31] = 5 + value, err := node.Get(key, resolver) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + + expectedValue := common.HexToHash("0xabcd").Bytes() + if !bytes.Equal(value, expectedValue) { + t.Errorf("Value mismatch: expected %x, got %x", expectedValue, value) + } +} + +// TestInternalNodeInsert tests the Insert method +func TestInternalNodeInsert(t *testing.T) { + // Start with an internal node with empty children + node := &InternalNode{ + depth: 0, + left: Empty{}, + right: Empty{}, + } + + // Insert a value into the left subtree + leftKey := make([]byte, 32) + leftKey[31] = 10 + leftValue := common.HexToHash("0x0101").Bytes() + + newNode, err := node.Insert(leftKey, leftValue, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check that left child is now a StemNode + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + + // Check the inserted value + if !bytes.Equal(leftStem.Values[10], leftValue) { + t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10]) + } + + // Right child should still be Empty + _, ok = internalNode.right.(Empty) + if !ok { + t.Errorf("Expected right child to remain Empty, got %T", internalNode.right) + } +} + +// TestInternalNodeCopy tests the Copy method +func TestInternalNodeCopy(t *testing.T) { + // Create an internal node with stem children + leftStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + leftStem.Values[0] = common.HexToHash("0x0101").Bytes() + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + rightStem.Stem[0] = 0x80 + rightStem.Values[0] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: leftStem, + right: rightStem, + } + + // Create a copy + copied := node.Copy() + copiedInternal, ok := copied.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", copied) + } + + // Check depth + if copiedInternal.depth != node.depth { + t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth) + } + + // Check that children are copied + copiedLeft, ok := copiedInternal.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left) + } + + copiedRight, ok := copiedInternal.right.(*StemNode) + if !ok { + t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right) + } + + // Verify deep copy (children should be different objects) + if copiedLeft == leftStem { + t.Error("Left child not properly copied") + } + if copiedRight == rightStem { + t.Error("Right child not properly copied") + } + + // But values should be equal + if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) { + t.Error("Left child value mismatch after copy") + } + if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) { + t.Error("Right child value mismatch after copy") + } +} + +// TestInternalNodeHash tests the Hash method +func TestInternalNodeHash(t *testing.T) { + // Create an internal node + node := &InternalNode{ + depth: 0, + left: HashedNode(common.HexToHash("0x1111")), + right: HashedNode(common.HexToHash("0x2222")), + } + + hash1 := node.Hash() + + // Hash should be deterministic + hash2 := node.Hash() + if hash1 != hash2 { + t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) + } + + // Changing a child should change the hash + node.left = HashedNode(common.HexToHash("0x3333")) + hash3 := node.Hash() + if hash1 == hash3 { + t.Error("Hash didn't change after modifying left child") + } + + // Test with nil children (should use zero hash) + nodeWithNil := &InternalNode{ + depth: 0, + left: nil, + right: HashedNode(common.HexToHash("0x4444")), + } + hashWithNil := nodeWithNil.Hash() + if hashWithNil == (common.Hash{}) { + t.Error("Hash shouldn't be zero even with nil child") + } +} + +// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method +func TestInternalNodeGetValuesAtStem(t *testing.T) { + // Create a tree with values at different stems + leftStem := make([]byte, 31) + rightStem := make([]byte, 31) + rightStem[0] = 0x80 + + var leftValues, rightValues [256][]byte + leftValues[0] = common.HexToHash("0x0101").Bytes() + leftValues[10] = common.HexToHash("0x0102").Bytes() + rightValues[0] = common.HexToHash("0x0201").Bytes() + rightValues[20] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: &StemNode{ + Stem: leftStem, + Values: leftValues[:], + depth: 1, + }, + right: &StemNode{ + Stem: rightStem, + Values: rightValues[:], + depth: 1, + }, + } + + // Get values from left stem + values, err := node.GetValuesAtStem(leftStem, nil) + if err != nil { + t.Fatalf("Failed to get left values: %v", err) + } + if !bytes.Equal(values[0], leftValues[0]) { + t.Error("Left value at index 0 mismatch") + } + if !bytes.Equal(values[10], leftValues[10]) { + t.Error("Left value at index 10 mismatch") + } + + // Get values from right stem + values, err = node.GetValuesAtStem(rightStem, nil) + if err != nil { + t.Fatalf("Failed to get right values: %v", err) + } + if !bytes.Equal(values[0], rightValues[0]) { + t.Error("Right value at index 0 mismatch") + } + if !bytes.Equal(values[20], rightValues[20]) { + t.Error("Right value at index 20 mismatch") + } +} + +// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method +func TestInternalNodeInsertValuesAtStem(t *testing.T) { + // Start with an internal node with empty children + node := &InternalNode{ + depth: 0, + left: Empty{}, + right: Empty{}, + } + + // Insert values at a stem in the left subtree + stem := make([]byte, 31) + var values [256][]byte + values[5] = common.HexToHash("0x0505").Bytes() + values[10] = common.HexToHash("0x1010").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check that left child is now a StemNode with the values + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + + if !bytes.Equal(leftStem.Values[5], values[5]) { + t.Error("Value at index 5 mismatch") + } + if !bytes.Equal(leftStem.Values[10], values[10]) { + t.Error("Value at index 10 mismatch") + } +} + +// TestInternalNodeCollectNodes tests CollectNodes method +func TestInternalNodeCollectNodes(t *testing.T) { + // Create an internal node with two stem children + leftStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + rightStem.Stem[0] = 0x80 + + node := &InternalNode{ + depth: 0, + left: leftStem, + right: rightStem, + } + + var collectedPaths [][]byte + var collectedNodes []BinaryNode + + flushFn := func(path []byte, n BinaryNode) { + pathCopy := make([]byte, len(path)) + copy(pathCopy, path) + collectedPaths = append(collectedPaths, pathCopy) + collectedNodes = append(collectedNodes, n) + } + + err := node.CollectNodes([]byte{1}, flushFn) + if err != nil { + t.Fatalf("Failed to collect nodes: %v", err) + } + + // Should have collected 3 nodes: left stem, right stem, and the internal node itself + if len(collectedNodes) != 3 { + t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes)) + } + + // Check paths + expectedPaths := [][]byte{ + {1, 0}, // left child + {1, 1}, // right child + {1}, // internal node itself + } + + for i, expectedPath := range expectedPaths { + if !bytes.Equal(collectedPaths[i], expectedPath) { + t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i]) + } + } +} + +// TestInternalNodeGetHeight tests GetHeight method +func TestInternalNodeGetHeight(t *testing.T) { + // Create a tree with different heights + // Left subtree: depth 2 (internal -> stem) + // Right subtree: depth 1 (stem) + leftInternal := &InternalNode{ + depth: 1, + left: &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 2, + }, + right: Empty{}, + } + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + + node := &InternalNode{ + depth: 0, + left: leftInternal, + right: rightStem, + } + + height := node.GetHeight() + // Height should be max(left height, right height) + 1 + // Left height: 2, Right height: 1, so total: 3 + if height != 3 { + t.Errorf("Expected height 3, got %d", height) + } +} + +// TestInternalNodeDepthTooLarge tests handling of excessive depth +func TestInternalNodeDepthTooLarge(t *testing.T) { + // Create an internal node at max depth + node := &InternalNode{ + depth: 31*8 + 1, + left: Empty{}, + right: Empty{}, + } + + stem := make([]byte, 31) + _, err := node.GetValuesAtStem(stem, nil) + if err == nil { + t.Fatal("Expected error for excessive depth") + } + if err.Error() != "node too deep" { + t.Errorf("Expected 'node too deep' error, got: %v", err) + } +} diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go new file mode 100644 index 00000000000..a6bab2bcfa9 --- /dev/null +++ b/trie/bintrie/iterator.go @@ -0,0 +1,261 @@ +// Copyright 2025 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/trie" +) + +var errIteratorEnd = errors.New("end of iteration") + +type binaryNodeIteratorState struct { + Node BinaryNode + Index int +} + +type binaryNodeIterator struct { + trie *BinaryTrie + current BinaryNode + lastErr error + + stack []binaryNodeIteratorState +} + +func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) { + if t.Hash() == zero { + return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil + } + it := &binaryNodeIterator{trie: t, current: t.root} + // it.err = it.seek(start) + return it, nil +} + +// Next moves the iterator to the next node. If the parameter is false, any child +// nodes will be skipped. +func (it *binaryNodeIterator) Next(descend bool) bool { + if it.lastErr == errIteratorEnd { + it.lastErr = errIteratorEnd + return false + } + + if len(it.stack) == 0 { + it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root}) + it.current = it.trie.root + + return true + } + + switch node := it.current.(type) { + case *InternalNode: + // index: 0 = nothing visited, 1=left visited, 2=right visited + context := &it.stack[len(it.stack)-1] + + // recurse into both children + if context.Index == 0 { + if _, isempty := node.left.(Empty); node.left != nil && !isempty { + it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left}) + it.current = node.left + return it.Next(descend) + } + + context.Index++ + } + + if context.Index == 1 { + if _, isempty := node.right.(Empty); node.right != nil && !isempty { + it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right}) + it.current = node.right + return it.Next(descend) + } + + context.Index++ + } + + // Reached the end of this node, go back to the parent, if + // this isn't root. + if len(it.stack) == 1 { + it.lastErr = errIteratorEnd + return false + } + it.stack = it.stack[:len(it.stack)-1] + it.current = it.stack[len(it.stack)-1].Node + it.stack[len(it.stack)-1].Index++ + return it.Next(descend) + case *StemNode: + // Look for the next non-empty value + for i := it.stack[len(it.stack)-1].Index; i < 256; i++ { + if node.Values[i] != nil { + it.stack[len(it.stack)-1].Index = i + 1 + return true + } + } + + // go back to parent to get the next leaf + it.stack = it.stack[:len(it.stack)-1] + it.current = it.stack[len(it.stack)-1].Node + it.stack[len(it.stack)-1].Index++ + return it.Next(descend) + case HashedNode: + // resolve the node + data, err := it.trie.nodeResolver(it.Path(), common.Hash(node)) + if err != nil { + panic(err) + } + it.current, err = DeserializeNode(data, len(it.stack)-1) + if err != nil { + panic(err) + } + + // update the stack and parent with the resolved node + it.stack[len(it.stack)-1].Node = it.current + parent := &it.stack[len(it.stack)-2] + if parent.Index == 0 { + parent.Node.(*InternalNode).left = it.current + } else { + parent.Node.(*InternalNode).right = it.current + } + return it.Next(descend) + case Empty: + // do nothing + return false + default: + panic("invalid node type") + } +} + +// Error returns the error status of the iterator. +func (it *binaryNodeIterator) Error() error { + if it.lastErr == errIteratorEnd { + return nil + } + return it.lastErr +} + +// Hash returns the hash of the current node. +func (it *binaryNodeIterator) Hash() common.Hash { + return it.current.Hash() +} + +// Parent returns the hash of the parent of the current node. The hash may be the one +// grandparent if the immediate parent is an internal node with no hash. +func (it *binaryNodeIterator) Parent() common.Hash { + return it.stack[len(it.stack)-1].Node.Hash() +} + +// Path returns the hex-encoded path to the current node. +// Callers must not retain references to the return value after calling Next. +// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. +func (it *binaryNodeIterator) Path() []byte { + if it.Leaf() { + return it.LeafKey() + } + var path []byte + for i, state := range it.stack { + // skip the last byte + if i >= len(it.stack)-1 { + break + } + path = append(path, byte(state.Index)) + } + return path +} + +// NodeBlob returns the serialized bytes of the current node. +func (it *binaryNodeIterator) NodeBlob() []byte { + return SerializeNode(it.current) +} + +// Leaf returns true iff the current node is a leaf node. +func (it *binaryNodeIterator) Leaf() bool { + _, ok := it.current.(*StemNode) + return ok +} + +// LeafKey returns the key of the leaf. The method panics if the iterator is not +// positioned at a leaf. Callers must not retain references to the value after +// calling Next. +func (it *binaryNodeIterator) LeafKey() []byte { + leaf, ok := it.current.(*StemNode) + if !ok { + panic("Leaf() called on an binary node iterator not at a leaf location") + } + return leaf.Key(it.stack[len(it.stack)-1].Index - 1) +} + +// LeafBlob returns the content of the leaf. The method panics if the iterator +// is not positioned at a leaf. Callers must not retain references to the value +// after calling Next. +func (it *binaryNodeIterator) LeafBlob() []byte { + leaf, ok := it.current.(*StemNode) + if !ok { + panic("LeafBlob() called on an binary node iterator not at a leaf location") + } + return leaf.Values[it.stack[len(it.stack)-1].Index-1] +} + +// LeafProof returns the Merkle proof of the leaf. The method panics if the +// iterator is not positioned at a leaf. Callers must not retain references +// to the value after calling Next. +func (it *binaryNodeIterator) LeafProof() [][]byte { + sn, ok := it.current.(*StemNode) + if !ok { + panic("LeafProof() called on an binary node iterator not at a leaf location") + } + + proof := make([][]byte, 0, len(it.stack)+NodeWidth) + + // Build proof by walking up the stack and collecting sibling hashes + for i := range it.stack[:len(it.stack)-2] { + state := it.stack[i] + internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode + + // Add the sibling hash to the proof + if state.Index == 0 { + // We came from left, so include right sibling + proof = append(proof, internalNode.right.Hash().Bytes()) + } else { + // We came from right, so include left sibling + proof = append(proof, internalNode.left.Hash().Bytes()) + } + } + + // Add the stem and siblings + proof = append(proof, sn.Stem) + for _, v := range sn.Values { + proof = append(proof, v) + } + + return proof +} + +// AddResolver sets an intermediate database to use for looking up trie nodes +// before reaching into the real persistent layer. +// +// This is not required for normal operation, rather is an optimization for +// cases where trie nodes can be recovered from some external mechanism without +// reading from disk. In those cases, this resolver allows short circuiting +// accesses and returning them from memory. +// +// Before adding a similar mechanism to any other place in Geth, consider +// making trie.Database an interface and wrapping at that level. It's a huge +// refactor, but it could be worth it if another occurrence arises. +func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) { + // Not implemented, but should not panic +} diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go new file mode 100644 index 00000000000..8773e9e0c54 --- /dev/null +++ b/trie/bintrie/iterator_test.go @@ -0,0 +1,83 @@ +// Copyright 2025 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/hashdb" + "github.com/ethereum/go-ethereum/triedb/pathdb" + "github.com/holiman/uint256" +) + +func newTestDatabase(diskdb ethdb.Database, scheme string) *triedb.Database { + config := &triedb.Config{Preimages: true} + if scheme == rawdb.HashScheme { + config.HashDB = &hashdb.Config{CleanCacheSize: 0} + } else { + config.PathDB = &pathdb.Config{TrieCleanSize: 0, StateCleanSize: 0} + } + return triedb.NewDatabase(diskdb, config) +} + +func TestBinaryIterator(t *testing.T) { + trie, err := NewBinaryTrie(types.EmptyVerkleHash, newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.PathScheme)) + if err != nil { + t.Fatal(err) + } + account0 := &types.StateAccount{ + Nonce: 1, + Balance: uint256.NewInt(2), + Root: types.EmptyRootHash, + CodeHash: nil, + } + // NOTE: the code size isn't written to the trie via TryUpdateAccount + // so it will be missing from the test nodes. + trie.UpdateAccount(common.Address{}, account0, 0) + account1 := &types.StateAccount{ + Nonce: 1337, + Balance: uint256.NewInt(2000), + Root: types.EmptyRootHash, + CodeHash: nil, + } + // This address is meant to hash to a value that has the same first byte as 0xbf + var clash = common.HexToAddress("69fd8034cdb20934dedffa7dccb4fb3b8062a8be") + trie.UpdateAccount(clash, account1, 0) + + // Manually go over every node to check that we get all + // the correct nodes. + it, err := trie.NodeIterator(nil) + if err != nil { + t.Fatal(err) + } + var leafcount int + for it.Next(true) { + t.Logf("Node: %x", it.Path()) + if it.Leaf() { + leafcount++ + t.Logf("\tLeaf: %x", it.LeafKey()) + } + } + if leafcount != 2 { + t.Fatalf("invalid leaf count: %d != 6", leafcount) + } +} diff --git a/trie/bintrie/key_encoding.go b/trie/bintrie/key_encoding.go new file mode 100644 index 00000000000..13c20573710 --- /dev/null +++ b/trie/bintrie/key_encoding.go @@ -0,0 +1,79 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "crypto/sha256" + + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +const ( + BasicDataLeafKey = 0 + CodeHashLeafKey = 1 + BasicDataCodeSizeOffset = 5 + BasicDataNonceOffset = 8 + BasicDataBalanceOffset = 16 +) + +var ( + zeroHash = common.Hash{} + codeOffset = uint256.NewInt(128) +) + +func GetBinaryTreeKey(addr common.Address, key []byte) []byte { + hasher := sha256.New() + hasher.Write(zeroHash[:12]) + hasher.Write(addr[:]) + hasher.Write(key[:31]) + k := hasher.Sum(nil) + k[31] = key[31] + return k +} + +func GetBinaryTreeKeyCodeHash(addr common.Address) []byte { + var k [32]byte + k[31] = CodeHashLeafKey + return GetBinaryTreeKey(addr, k[:]) +} + +func GetBinaryTreeKeyStorageSlot(address common.Address, key []byte) []byte { + var k [32]byte + + // Case when the key belongs to the account header + if bytes.Equal(key[:31], zeroHash[:31]) && key[31] < 64 { + k[31] = 64 + key[31] + return GetBinaryTreeKey(address, k[:]) + } + + // Set the main storage offset + // note that the first 64 bytes of the main offset storage + // are unreachable, which is consistent with the spec and + // what verkle does. + k[0] = 1 // 1 << 248 + copy(k[1:], key[:31]) + k[31] = key[31] + + return GetBinaryTreeKey(address, k[:]) +} + +func GetBinaryTreeKeyCodeChunk(address common.Address, chunknr *uint256.Int) []byte { + chunkOffset := new(uint256.Int).Add(codeOffset, chunknr).Bytes() + return GetBinaryTreeKey(address, chunkOffset) +} diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go new file mode 100644 index 00000000000..50c06c9761e --- /dev/null +++ b/trie/bintrie/stem_node.go @@ -0,0 +1,213 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "slices" + + "github.com/ethereum/go-ethereum/common" +) + +// StemNode represents a group of `NodeWith` values sharing the same stem. +type StemNode struct { + Stem []byte // Stem path to get to 256 values + Values [][]byte // All values, indexed by the last byte of the key. + depth int // Depth of the node +} + +// Get retrieves the value for the given key. +func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) { + panic("this should not be called directly") +} + +// Insert inserts a new key-value pair into the node. +func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + if !bytes.Equal(bt.Stem, key[:31]) { + bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + + n := &InternalNode{depth: bt.depth} + bt.depth++ + var child, other *BinaryNode + if bitStem == 0 { + n.left = bt + child = &n.left + other = &n.right + } else { + n.right = bt + child = &n.right + other = &n.left + } + + bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 + if bitKey == bitStem { + var err error + *child, err = (*child).Insert(key, value, nil, depth+1) + if err != nil { + return n, fmt.Errorf("insert error: %w", err) + } + *other = Empty{} + } else { + var values [256][]byte + values[key[31]] = value + *other = &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values[:], + depth: depth + 1, + } + } + return n, nil + } + if len(value) != 32 { + return bt, errors.New("invalid insertion: value length") + } + bt.Values[key[31]] = value + return bt, nil +} + +// Copy creates a deep copy of the node. +func (bt *StemNode) Copy() BinaryNode { + var values [256][]byte + for i, v := range bt.Values { + values[i] = slices.Clone(v) + } + return &StemNode{ + Stem: slices.Clone(bt.Stem), + Values: values[:], + depth: bt.depth, + } +} + +// GetHeight returns the height of the node. +func (bt *StemNode) GetHeight() int { + return 1 +} + +// Hash returns the hash of the node. +func (bt *StemNode) Hash() common.Hash { + var data [NodeWidth]common.Hash + for i, v := range bt.Values { + if v != nil { + h := sha256.Sum256(v) + data[i] = common.BytesToHash(h[:]) + } + } + + h := sha256.New() + for level := 1; level <= 8; level++ { + for i := range NodeWidth / (1 << level) { + h.Reset() + + if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) { + data[i] = common.Hash{} + continue + } + + h.Write(data[i*2][:]) + h.Write(data[i*2+1][:]) + data[i] = common.Hash(h.Sum(nil)) + } + } + + h.Reset() + h.Write(bt.Stem) + h.Write([]byte{0}) + h.Write(data[0][:]) + return common.BytesToHash(h.Sum(nil)) +} + +// CollectNodes collects all child nodes at a given path, and flushes it +// into the provided node collector. +func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error { + flush(path, bt) + return nil +} + +// GetValuesAtStem retrieves the group of values located at the given stem key. +func (bt *StemNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + return bt.Values[:], nil +} + +// InsertValuesAtStem inserts a full value group at the given stem in the internal node. +// Already-existing values will be overwritten. +func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + if !bytes.Equal(bt.Stem, key[:31]) { + bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + + n := &InternalNode{depth: bt.depth} + bt.depth++ + var child, other *BinaryNode + if bitStem == 0 { + n.left = bt + child = &n.left + other = &n.right + } else { + n.right = bt + child = &n.right + other = &n.left + } + + bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 + if bitKey == bitStem { + var err error + *child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1) + if err != nil { + return n, fmt.Errorf("insert error: %w", err) + } + *other = Empty{} + } else { + *other = &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values, + depth: n.depth + 1, + } + } + return n, nil + } + + // same stem, just merge the two value lists + for i, v := range values { + if v != nil { + bt.Values[i] = v + } + } + return bt, nil +} + +func (bt *StemNode) toDot(parent, path string) string { + me := fmt.Sprintf("stem%s", path) + ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash()) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + for i, v := range bt.Values { + if v != nil { + ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v) + ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i) + } + } + return ret +} + +// Key returns the full key for the given index. +func (bt *StemNode) Key(i int) []byte { + var ret [32]byte + copy(ret[:], bt.Stem) + ret[StemSize] = byte(i) + return ret[:] +} diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go new file mode 100644 index 00000000000..e0ffd5c3c84 --- /dev/null +++ b/trie/bintrie/stem_node_test.go @@ -0,0 +1,373 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestStemNodeInsertSameStem tests inserting values with the same stem +func TestStemNodeInsertSameStem(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Insert another value with the same stem but different last byte + key := make([]byte, 32) + copy(key[:31], stem) + key[31] = 10 + value := common.HexToHash("0x0202").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should still be a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check that both values are present + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Errorf("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], value) { + t.Errorf("Value at index 10 mismatch") + } +} + +// TestStemNodeInsertDifferentStem tests inserting values with different stems +func TestStemNodeInsertDifferentStem(t *testing.T) { + stem1 := make([]byte, 31) + for i := range stem1 { + stem1[i] = 0x00 + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem1, + Values: values[:], + depth: 0, + } + + // Insert with a different stem (first bit different) + key := make([]byte, 32) + key[0] = 0x80 // First bit is 1 instead of 0 + value := common.HexToHash("0x0202").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should now be an InternalNode + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check depth + if internalNode.depth != 0 { + t.Errorf("Expected depth 0, got %d", internalNode.depth) + } + + // Original stem should be on the left (bit 0) + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + if !bytes.Equal(leftStem.Stem, stem1) { + t.Errorf("Left stem mismatch") + } + + // New stem should be on the right (bit 1) + rightStem, ok := internalNode.right.(*StemNode) + if !ok { + t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right) + } + if !bytes.Equal(rightStem.Stem, key[:31]) { + t.Errorf("Right stem mismatch") + } +} + +// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length +func TestStemNodeInsertInvalidValueLength(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Try to insert value with wrong length + key := make([]byte, 32) + copy(key[:31], stem) + invalidValue := []byte{1, 2, 3} // Not 32 bytes + + _, err := node.Insert(key, invalidValue, nil, 0) + if err == nil { + t.Fatal("Expected error for invalid value length") + } + + if err.Error() != "invalid insertion: value length" { + t.Errorf("Expected 'invalid insertion: value length' error, got: %v", err) + } +} + +// TestStemNodeCopy tests the Copy method +func TestStemNodeCopy(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[255] = common.HexToHash("0x0202").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 10, + } + + // Create a copy + copied := node.Copy() + copiedStem, ok := copied.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", copied) + } + + // Check that values are equal but not the same slice + if !bytes.Equal(copiedStem.Stem, node.Stem) { + t.Errorf("Stem mismatch after copy") + } + if &copiedStem.Stem[0] == &node.Stem[0] { + t.Error("Stem slice not properly cloned") + } + + // Check values + if !bytes.Equal(copiedStem.Values[0], node.Values[0]) { + t.Errorf("Value at index 0 mismatch after copy") + } + if !bytes.Equal(copiedStem.Values[255], node.Values[255]) { + t.Errorf("Value at index 255 mismatch after copy") + } + + // Check that value slices are cloned + if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] { + t.Error("Value slice not properly cloned") + } + + // Check depth + if copiedStem.depth != node.depth { + t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth) + } +} + +// TestStemNodeHash tests the Hash method +func TestStemNodeHash(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + hash1 := node.Hash() + + // Hash should be deterministic + hash2 := node.Hash() + if hash1 != hash2 { + t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) + } + + // Changing a value should change the hash + node.Values[1] = common.HexToHash("0x0202").Bytes() + hash3 := node.Hash() + if hash1 == hash3 { + t.Error("Hash didn't change after modifying values") + } +} + +// TestStemNodeGetValuesAtStem tests GetValuesAtStem method +func TestStemNodeGetValuesAtStem(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[10] = common.HexToHash("0x0202").Bytes() + values[255] = common.HexToHash("0x0303").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // GetValuesAtStem with matching stem + retrievedValues, err := node.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatalf("Failed to get values: %v", err) + } + + // Check that all values match + for i := 0; i < 256; i++ { + if !bytes.Equal(retrievedValues[i], values[i]) { + t.Errorf("Value mismatch at index %d", i) + } + } + + // GetValuesAtStem with different stem also returns the same values + // (implementation ignores the stem parameter) + differentStem := make([]byte, 31) + differentStem[0] = 0xFF + + retrievedValues2, err := node.GetValuesAtStem(differentStem, nil) + if err != nil { + t.Fatalf("Failed to get values with different stem: %v", err) + } + + // Should still return the same values (stem is ignored) + for i := 0; i < 256; i++ { + if !bytes.Equal(retrievedValues2[i], values[i]) { + t.Errorf("Value mismatch at index %d with different stem", i) + } + } +} + +// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method +func TestStemNodeInsertValuesAtStem(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Insert new values at the same stem + var newValues [256][]byte + newValues[1] = common.HexToHash("0x0202").Bytes() + newValues[2] = common.HexToHash("0x0303").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check that all values are present + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Error("Original value at index 0 missing") + } + if !bytes.Equal(stemNode.Values[1], newValues[1]) { + t.Error("New value at index 1 missing") + } + if !bytes.Equal(stemNode.Values[2], newValues[2]) { + t.Error("New value at index 2 missing") + } +} + +// TestStemNodeGetHeight tests GetHeight method +func TestStemNodeGetHeight(t *testing.T) { + node := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 0, + } + + height := node.GetHeight() + if height != 1 { + t.Errorf("Expected height 1, got %d", height) + } +} + +// TestStemNodeCollectNodes tests CollectNodes method +func TestStemNodeCollectNodes(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + var collectedPaths [][]byte + var collectedNodes []BinaryNode + + flushFn := func(path []byte, n BinaryNode) { + // Make a copy of the path + pathCopy := make([]byte, len(path)) + copy(pathCopy, path) + collectedPaths = append(collectedPaths, pathCopy) + collectedNodes = append(collectedNodes, n) + } + + err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + if err != nil { + t.Fatalf("Failed to collect nodes: %v", err) + } + + // Should have collected one node (itself) + if len(collectedNodes) != 1 { + t.Errorf("Expected 1 collected node, got %d", len(collectedNodes)) + } + + // Check that the collected node is the same + if collectedNodes[0] != node { + t.Error("Collected node doesn't match original") + } + + // Check the path + if !bytes.Equal(collectedPaths[0], []byte{0, 1, 0}) { + t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0]) + } +} diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go new file mode 100644 index 00000000000..0a8bd325f58 --- /dev/null +++ b/trie/bintrie/trie.go @@ -0,0 +1,353 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/ethereum/go-ethereum/triedb/database" + "github.com/holiman/uint256" +) + +var errInvalidRootType = errors.New("invalid root type") + +// NewBinaryNode creates a new empty binary trie +func NewBinaryNode() BinaryNode { + return Empty{} +} + +// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864. +type BinaryTrie struct { + root BinaryNode + reader *trie.Reader + tracer *trie.PrevalueTracer +} + +// ToDot converts the binary trie to a DOT language representation. Useful for debugging. +func (t *BinaryTrie) ToDot() string { + t.root.Hash() + return ToDot(t.root) +} + +// NewBinaryTrie creates a new binary trie. +func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, error) { + reader, err := trie.NewReader(root, common.Hash{}, db) + if err != nil { + return nil, err + } + t := &BinaryTrie{ + root: NewBinaryNode(), + reader: reader, + tracer: trie.NewPrevalueTracer(), + } + // Parse the root node if it's not empty + if root != types.EmptyBinaryHash && root != types.EmptyRootHash { + blob, err := t.nodeResolver(nil, root) + if err != nil { + return nil, err + } + node, err := DeserializeNode(blob, 0) + if err != nil { + return nil, err + } + t.root = node + } + return t, nil +} + +// nodeResolver is a node resolver that reads nodes from the flatdb. +func (t *BinaryTrie) nodeResolver(path []byte, hash common.Hash) ([]byte, error) { + // empty nodes will be serialized as common.Hash{}, so capture + // this special use case. + if hash == (common.Hash{}) { + return nil, nil // empty node + } + blob, err := t.reader.Node(path, hash) + if err != nil { + return nil, err + } + t.tracer.Put(path, blob) + return blob, nil +} + +// GetKey returns the sha3 preimage of a hashed key that was previously used +// to store a value. +func (t *BinaryTrie) GetKey(key []byte) []byte { + return key +} + +// GetWithHashedKey returns the value, assuming that the key has already +// been hashed. +func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) { + return t.root.Get(key, t.nodeResolver) +} + +// GetAccount returns the account information for the given address. +func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) { + var ( + values [][]byte + err error + acc = &types.StateAccount{} + key = GetBinaryTreeKey(addr, zero[:]) + ) + switch r := t.root.(type) { + case *InternalNode: + values, err = r.GetValuesAtStem(key[:31], t.nodeResolver) + case *StemNode: + values = r.Values + case Empty: + return nil, nil + default: + // This will cover HashedNode but that should be fine since the + // root node should always be resolved. + return nil, errInvalidRootType + } + if err != nil { + return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err) + } + + // The following code is required for the MPT->Binary conversion. + // An account can be partially migrated, where storage slots were moved to the binary + // but not yet the account. This means some account information as (header) storage slots + // are in the binary trie but basic account information must be read in the base tree (MPT). + // TODO: we can simplify this logic depending if the conversion is in progress or finished. + emptyAccount := true + for i := 0; values != nil && i <= CodeHashLeafKey && emptyAccount; i++ { + emptyAccount = emptyAccount && values[i] == nil + } + if emptyAccount { + return nil, nil + } + + // If the account has been deleted, then values[10] will be 0 and not nil. If it has + // been recreated after that, then its code keccak will NOT be 0. So return `nil` if + // the nonce, and values[10], and code keccak is 0. + if bytes.Equal(values[BasicDataLeafKey], zero[:]) && len(values) > 10 && len(values[10]) > 0 && bytes.Equal(values[CodeHashLeafKey], zero[:]) { + return nil, nil + } + + acc.Nonce = binary.BigEndian.Uint64(values[BasicDataLeafKey][BasicDataNonceOffset:]) + var balance [16]byte + copy(balance[:], values[BasicDataLeafKey][BasicDataBalanceOffset:]) + acc.Balance = new(uint256.Int).SetBytes(balance[:]) + acc.CodeHash = values[CodeHashLeafKey] + + return acc, nil +} + +// GetStorage returns the value for key stored in the trie. The value bytes must +// not be modified by the caller. If a node was not found in the database, a +// trie.MissingNodeError is returned. +func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) { + return t.root.Get(GetBinaryTreeKey(addr, key), t.nodeResolver) +} + +// UpdateAccount updates the account information for the given address. +func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error { + var ( + err error + basicData [32]byte + values = make([][]byte, NodeWidth) + stem = GetBinaryTreeKey(addr, zero[:]) + ) + binary.BigEndian.PutUint32(basicData[BasicDataCodeSizeOffset-1:], uint32(codeLen)) + binary.BigEndian.PutUint64(basicData[BasicDataNonceOffset:], acc.Nonce) + + // Because the balance is a max of 16 bytes, truncate + // the extra values. This happens in devmode, where + // 0xff**32 is allocated to the developer account. + balanceBytes := acc.Balance.Bytes() + // TODO: reduce the size of the allocation in devmode, then panic instead + // of truncating. + if len(balanceBytes) > 16 { + balanceBytes = balanceBytes[16:] + } + copy(basicData[32-len(balanceBytes):], balanceBytes[:]) + values[BasicDataLeafKey] = basicData[:] + values[CodeHashLeafKey] = acc.CodeHash[:] + + t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) + return err +} + +// UpdateStem updates the values for the given stem key. +func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { + var err error + t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0) + return err +} + +// UpdateStorage associates key with value in the trie. If value has length zero, any +// existing value is deleted from the trie. The value bytes must not be modified +// by the caller while they are stored in the trie. If a node was not found in the +// database, a trie.MissingNodeError is returned. +func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) error { + k := GetBinaryTreeKeyStorageSlot(address, key) + var v [32]byte + if len(value) >= 32 { + copy(v[:], value[:32]) + } else { + copy(v[32-len(value):], value[:]) + } + root, err := t.root.Insert(k, v[:], t.nodeResolver, 0) + if err != nil { + return fmt.Errorf("UpdateStorage (%x) error: %v", address, err) + } + t.root = root + return nil +} + +// DeleteAccount is a no-op as it is disabled in stateless. +func (t *BinaryTrie) DeleteAccount(addr common.Address) error { + return nil +} + +// DeleteStorage removes any existing value for key from the trie. If a node was not +// found in the database, a trie.MissingNodeError is returned. +func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { + k := GetBinaryTreeKey(addr, key) + var zero [32]byte + root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0) + if err != nil { + return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err) + } + t.root = root + return nil +} + +// Hash returns the root hash of the trie. It does not write to the database and +// can be used even if the trie doesn't have one. +func (t *BinaryTrie) Hash() common.Hash { + return t.root.Hash() +} + +// Commit writes all nodes to the trie's memory database, tracking the internal +// and external (for account tries) references. +func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { + root := t.root.(*InternalNode) + nodeset := trienode.NewNodeSet(common.Hash{}) + + err := root.CollectNodes(nil, func(path []byte, node BinaryNode) { + serialized := SerializeNode(node) + nodeset.AddNode(path, trienode.NewNodeWithPrev(common.Hash{}, serialized, t.tracer.Get(path))) + }) + if err != nil { + panic(fmt.Errorf("CollectNodes failed: %v", err)) + } + // Serialize root commitment form + return t.Hash(), nodeset +} + +// NodeIterator returns an iterator that returns nodes of the trie. Iteration +// starts at the key after the given start key. +func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) { + return newBinaryNodeIterator(t, nil) +} + +// Prove constructs a Merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. +// +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root), ending +// with the node that proves the absence of the key. +func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { + panic("not implemented") +} + +// Copy creates a deep copy of the trie. +func (t *BinaryTrie) Copy() *BinaryTrie { + return &BinaryTrie{ + root: t.root.Copy(), + reader: t.reader, + tracer: t.tracer.Copy(), + } +} + +// IsVerkle returns true if the trie is a Verkle tree. +func (t *BinaryTrie) IsVerkle() bool { + // TODO @gballet This is technically NOT a verkle tree, but it has the same + // behavior and basic structure, so for all intents and purposes, it can be + // treated as such. Rename this when verkle gets removed. + return true +} + +// UpdateContractCode updates the contract code into the trie. +// +// Note: the basic data leaf needs to have been previously created for this to work +func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Hash, code []byte) error { + var ( + chunks = trie.ChunkifyCode(code) + values [][]byte + key []byte + err error + ) + for i, chunknr := 0, uint64(0); i < len(chunks); i, chunknr = i+32, chunknr+1 { + groupOffset := (chunknr + 128) % 256 + if groupOffset == 0 /* start of new group */ || chunknr == 0 /* first chunk in header group */ { + values = make([][]byte, NodeWidth) + var offset [32]byte + binary.LittleEndian.PutUint64(offset[24:], chunknr+128) + key = GetBinaryTreeKey(addr, offset[:]) + } + values[groupOffset] = chunks[i : i+32] + + if groupOffset == 255 || len(chunks)-i <= 32 { + err = t.UpdateStem(key[:31], values) + + if err != nil { + return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err) + } + } + } + return nil +} + +// PrefetchAccount attempts to resolve specific accounts from the database +// to accelerate subsequent trie operations. +func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error { + for _, addr := range addresses { + if _, err := t.GetAccount(addr); err != nil { + return err + } + } + return nil +} + +// PrefetchStorage attempts to resolve specific storage slots from the database +// to accelerate subsequent trie operations. +func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error { + for _, key := range keys { + if _, err := t.GetStorage(addr, key); err != nil { + return err + } + } + return nil +} + +// Witness returns a set containing all trie nodes that have been accessed. +func (t *BinaryTrie) Witness() map[string][]byte { + panic("not implemented") +} diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go new file mode 100644 index 00000000000..84f76895494 --- /dev/null +++ b/trie/bintrie/trie_test.go @@ -0,0 +1,197 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +var ( + zeroKey = [32]byte{} + oneKey = common.HexToHash("0101010101010101010101010101010101010101010101010101010101010101") + twoKey = common.HexToHash("0202020202020202020202020202020202020202020202020202020202020202") + threeKey = common.HexToHash("0303030303030303030303030303030303030303030303030303030303030303") + fourKey = common.HexToHash("0404040404040404040404040404040404040404040404040404040404040404") + ffKey = common.HexToHash("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") +) + +func TestSingleEntry(t *testing.T) { + tree := NewBinaryNode() + tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid depth") + } + expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f") + got := tree.Hash() + if got != expected { + t.Fatalf("invalid tree root, got %x, want %x", got, expected) + } +} + +func TestTwoEntriesDiffFirstBit(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 2 { + t.Fatal("invalid height") + } + if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") { + t.Fatal("invalid tree root") + } +} + +func TestOneStemColocatedValues(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid height") + } +} + +func TestTwoStemColocatedValues(t *testing.T) { + var err error + tree := NewBinaryNode() + // stem: 0...0 + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + // stem: 10...0 + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 2 { + t.Fatal("invalid height") + } +} + +func TestTwoKeysMatchFirst42Bits(t *testing.T) { + var err error + tree := NewBinaryNode() + // key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after. + key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes() + key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes() + tree, err = tree.Insert(key1, oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(key2, twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1+42+1 { + t.Fatal("invalid height") + } +} +func TestInsertDuplicateKey(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid height") + } + // Verify that the value is updated + if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) { + t.Fatal("invalid height") + } +} +func TestLargeNumberOfEntries(t *testing.T) { + var err error + tree := NewBinaryNode() + for i := range 256 { + var key [32]byte + key[0] = byte(i) + tree, err = tree.Insert(key[:], ffKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + } + height := tree.GetHeight() + if height != 1+8 { + t.Fatalf("invalid height, wanted %d, got %d", 1+8, height) + } +} + +func TestMerkleizeMultipleEntries(t *testing.T) { + var err error + tree := NewBinaryNode() + keys := [][]byte{ + zeroKey[:], + common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), + common.HexToHash("0100000000000000000000000000000000000000000000000000000000000000").Bytes(), + common.HexToHash("8100000000000000000000000000000000000000000000000000000000000000").Bytes(), + } + for i, key := range keys { + var v [32]byte + binary.LittleEndian.PutUint64(v[:8], uint64(i)) + tree, err = tree.Insert(key, v[:], nil, 0) + if err != nil { + t.Fatal(err) + } + } + got := tree.Hash() + expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5") + if got != expected { + t.Fatalf("invalid root, expected=%x, got = %x", expected, got) + } +} diff --git a/trie/committer.go b/trie/committer.go index a040868c6c8..2a2142e0ffa 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -29,12 +29,12 @@ import ( // insertion order. type committer struct { nodes *trienode.NodeSet - tracer *prevalueTracer + tracer *PrevalueTracer collectLeaf bool } // newCommitter creates a new committer or picks one from the pool. -func newCommitter(nodeset *trienode.NodeSet, tracer *prevalueTracer, collectLeaf bool) *committer { +func newCommitter(nodeset *trienode.NodeSet, tracer *PrevalueTracer, collectLeaf bool) *committer { return &committer{ nodes: nodeset, tracer: tracer, @@ -142,7 +142,7 @@ func (c *committer) store(path []byte, n node) node { // The node is embedded in its parent, in other words, this node // will not be stored in the database independently, mark it as // deleted only if the node was existent in database before. - origin := c.tracer.get(path) + origin := c.tracer.Get(path) if len(origin) != 0 { c.nodes.AddNode(path, trienode.NewDeletedWithPrev(origin)) } @@ -150,7 +150,7 @@ func (c *committer) store(path []byte, n node) node { } // Collect the dirty node to nodeset for return. nhash := common.BytesToHash(hash) - c.nodes.AddNode(path, trienode.NewNodeWithPrev(nhash, nodeToBytes(n), c.tracer.get(path))) + c.nodes.AddNode(path, trienode.NewNodeWithPrev(nhash, nodeToBytes(n), c.tracer.Get(path))) // Collect the corresponding leaf node if it's required. We don't check // full node since it's impossible to store value in fullNode. The key diff --git a/trie/iterator.go b/trie/iterator.go index e6fedf24309..80298ce48f1 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -405,7 +405,7 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - blob, err := it.trie.reader.node(path, common.BytesToHash(hash)) + blob, err := it.trie.reader.Node(path, common.BytesToHash(hash)) if err != nil { return nil, err } @@ -426,7 +426,7 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - return it.trie.reader.node(path, common.BytesToHash(hash)) + return it.trie.reader.Node(path, common.BytesToHash(hash)) } func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { diff --git a/trie/proof.go b/trie/proof.go index f3ed417094d..1a06ed5d5e3 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -69,7 +69,7 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - blob, err := t.reader.node(prefix, common.BytesToHash(n)) + blob, err := t.reader.Node(prefix, common.BytesToHash(n)) if err != nil { log.Error("Unhandled trie error in Trie.Prove", "err", err) return err @@ -571,7 +571,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu root: root, reader: newEmptyReader(), opTracer: newOpTracer(), - prevalueTracer: newPrevalueTracer(), + prevalueTracer: NewPrevalueTracer(), } if empty { tr.root = nil diff --git a/trie/tracer.go b/trie/tracer.go index b0542404a7c..04122d1384f 100644 --- a/trie/tracer.go +++ b/trie/tracer.go @@ -94,46 +94,45 @@ func (t *opTracer) deletedList() [][]byte { return paths } -// prevalueTracer tracks the original values of resolved trie nodes. Cached trie +// PrevalueTracer tracks the original values of resolved trie nodes. Cached trie // node values are expected to be immutable. A zero-size node value is treated as // non-existent and should not occur in practice. // -// Note prevalueTracer is not thread-safe, callers should be responsible for -// handling the concurrency issues by themselves. -type prevalueTracer struct { +// Note PrevalueTracer is thread-safe. +type PrevalueTracer struct { data map[string][]byte lock sync.RWMutex } -// newPrevalueTracer initializes the tracer for capturing resolved trie nodes. -func newPrevalueTracer() *prevalueTracer { - return &prevalueTracer{ +// NewPrevalueTracer initializes the tracer for capturing resolved trie nodes. +func NewPrevalueTracer() *PrevalueTracer { + return &PrevalueTracer{ data: make(map[string][]byte), } } -// put tracks the newly loaded trie node and caches its RLP-encoded +// Put tracks the newly loaded trie node and caches its RLP-encoded // blob internally. Do not modify the value outside this function, // as it is not deep-copied. -func (t *prevalueTracer) put(path []byte, val []byte) { +func (t *PrevalueTracer) Put(path []byte, val []byte) { t.lock.Lock() defer t.lock.Unlock() t.data[string(path)] = val } -// get returns the cached trie node value. If the node is not found, nil will +// Get returns the cached trie node value. If the node is not found, nil will // be returned. -func (t *prevalueTracer) get(path []byte) []byte { +func (t *PrevalueTracer) Get(path []byte) []byte { t.lock.RLock() defer t.lock.RUnlock() return t.data[string(path)] } -// hasList returns a list of flags indicating whether the corresponding trie nodes +// HasList returns a list of flags indicating whether the corresponding trie nodes // specified by the path exist in the trie. -func (t *prevalueTracer) hasList(list [][]byte) []bool { +func (t *PrevalueTracer) HasList(list [][]byte) []bool { t.lock.RLock() defer t.lock.RUnlock() @@ -145,29 +144,29 @@ func (t *prevalueTracer) hasList(list [][]byte) []bool { return exists } -// values returns a list of values of the cached trie nodes. -func (t *prevalueTracer) values() map[string][]byte { +// Values returns a list of values of the cached trie nodes. +func (t *PrevalueTracer) Values() map[string][]byte { t.lock.RLock() defer t.lock.RUnlock() return maps.Clone(t.data) } -// reset resets the cached content in the prevalueTracer. -func (t *prevalueTracer) reset() { +// Reset resets the cached content in the prevalueTracer. +func (t *PrevalueTracer) Reset() { t.lock.Lock() defer t.lock.Unlock() clear(t.data) } -// copy returns a copied prevalueTracer instance. -func (t *prevalueTracer) copy() *prevalueTracer { +// Copy returns a copied prevalueTracer instance. +func (t *PrevalueTracer) Copy() *PrevalueTracer { t.lock.RLock() defer t.lock.RUnlock() // Shadow clone is used, as the cached trie node values are immutable - return &prevalueTracer{ + return &PrevalueTracer{ data: maps.Clone(t.data), } } diff --git a/trie/trie.go b/trie/trie.go index 98cf751f477..630462f8ca5 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -55,11 +55,11 @@ type Trie struct { uncommitted int // reader is the handler trie can retrieve nodes from. - reader *trieReader + reader *Reader // Various tracers for capturing the modifications to trie opTracer *opTracer - prevalueTracer *prevalueTracer + prevalueTracer *PrevalueTracer } // newFlag returns the cache flag value for a newly created node. @@ -77,7 +77,7 @@ func (t *Trie) Copy() *Trie { uncommitted: t.uncommitted, reader: t.reader, opTracer: t.opTracer.copy(), - prevalueTracer: t.prevalueTracer.copy(), + prevalueTracer: t.prevalueTracer.Copy(), } } @@ -88,7 +88,7 @@ func (t *Trie) Copy() *Trie { // empty, otherwise, the root node must be present in database or returns // a MissingNodeError if not. func New(id *ID, db database.NodeDatabase) (*Trie, error) { - reader, err := newTrieReader(id.StateRoot, id.Owner, db) + reader, err := NewReader(id.StateRoot, id.Owner, db) if err != nil { return nil, err } @@ -96,7 +96,7 @@ func New(id *ID, db database.NodeDatabase) (*Trie, error) { owner: id.Owner, reader: reader, opTracer: newOpTracer(), - prevalueTracer: newPrevalueTracer(), + prevalueTracer: NewPrevalueTracer(), } if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { rootnode, err := trie.resolveAndTrack(id.Root[:], nil) @@ -289,7 +289,7 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod if hash == nil { return nil, origNode, 0, errors.New("non-consensus node") } - blob, err := t.reader.node(path, common.BytesToHash(hash)) + blob, err := t.reader.Node(path, common.BytesToHash(hash)) return blob, origNode, 1, err } // Path still needs to be traversed, descend into children @@ -655,11 +655,11 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { // node's original value. The rlp-encoded blob is preferred to be loaded from // database because it's easy to decode node while complex to encode node to blob. func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { - blob, err := t.reader.node(prefix, common.BytesToHash(n)) + blob, err := t.reader.Node(prefix, common.BytesToHash(n)) if err != nil { return nil, err } - t.prevalueTracer.put(prefix, blob) + t.prevalueTracer.Put(prefix, blob) // The returned node blob won't be changed afterward. No need to // deep-copy the slice. @@ -673,7 +673,7 @@ func (t *Trie) deletedNodes() [][]byte { var ( pos int list = t.opTracer.deletedList() - flags = t.prevalueTracer.hasList(list) + flags = t.prevalueTracer.HasList(list) ) for i := 0; i < len(list); i++ { if flags[i] { @@ -711,7 +711,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } nodes := trienode.NewNodeSet(t.owner) for _, path := range paths { - nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.get(path))) + nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.Get(path))) } return types.EmptyRootHash, nodes // case (b) } @@ -729,7 +729,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } nodes := trienode.NewNodeSet(t.owner) for _, path := range t.deletedNodes() { - nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.get(path))) + nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.Get(path))) } // If the number of changes is below 100, we let one thread handle it t.root = newCommitter(nodes, t.prevalueTracer, collectLeaf).Commit(t.root, t.uncommitted > 100) @@ -753,7 +753,7 @@ func (t *Trie) hashRoot() []byte { // Witness returns a set containing all trie nodes that have been accessed. func (t *Trie) Witness() map[string][]byte { - return t.prevalueTracer.values() + return t.prevalueTracer.Values() } // Reset drops the referenced root node and cleans all internal state. @@ -763,6 +763,6 @@ func (t *Trie) Reset() { t.unhashed = 0 t.uncommitted = 0 t.opTracer.reset() - t.prevalueTracer.reset() + t.prevalueTracer.Reset() t.committed = false } diff --git a/trie/trie_reader.go b/trie/trie_reader.go index a42cdb0cf98..42fe4d72c7b 100644 --- a/trie/trie_reader.go +++ b/trie/trie_reader.go @@ -22,39 +22,39 @@ import ( "github.com/ethereum/go-ethereum/triedb/database" ) -// trieReader is a wrapper of the underlying node reader. It's not safe +// Reader is a wrapper of the underlying database reader. It's not safe // for concurrent usage. -type trieReader struct { +type Reader struct { owner common.Hash reader database.NodeReader banned map[string]struct{} // Marker to prevent node from being accessed, for tests } -// newTrieReader initializes the trie reader with the given node reader. -func newTrieReader(stateRoot, owner common.Hash, db database.NodeDatabase) (*trieReader, error) { +// NewReader initializes the trie reader with the given database reader. +func NewReader(stateRoot, owner common.Hash, db database.NodeDatabase) (*Reader, error) { if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { - return &trieReader{owner: owner}, nil + return &Reader{owner: owner}, nil } reader, err := db.NodeReader(stateRoot) if err != nil { return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err} } - return &trieReader{owner: owner, reader: reader}, nil + return &Reader{owner: owner, reader: reader}, nil } // newEmptyReader initializes the pure in-memory reader. All read operations // should be forbidden and returns the MissingNodeError. -func newEmptyReader() *trieReader { - return &trieReader{} +func newEmptyReader() *Reader { + return &Reader{} } -// node retrieves the rlp-encoded trie node with the provided trie node +// Node retrieves the rlp-encoded trie node with the provided trie node // information. An MissingNodeError will be returned in case the node is // not found or any error is encountered. // // Don't modify the returned byte slice since it's not deep-copied and // still be referenced by database. -func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) { +func (r *Reader) Node(path []byte, hash common.Hash) ([]byte, error) { // Perform the logics in tests for preventing trie node access. if r.banned != nil { if _, ok := r.banned[string(path)]; ok { diff --git a/trie/verkle.go b/trie/verkle.go index e00ea21602c..186ac1f642b 100644 --- a/trie/verkle.go +++ b/trie/verkle.go @@ -41,13 +41,13 @@ var ( type VerkleTrie struct { root verkle.VerkleNode cache *utils.PointCache - reader *trieReader - tracer *prevalueTracer + reader *Reader + tracer *PrevalueTracer } // NewVerkleTrie constructs a verkle tree based on the specified root hash. func NewVerkleTrie(root common.Hash, db database.NodeDatabase, cache *utils.PointCache) (*VerkleTrie, error) { - reader, err := newTrieReader(root, common.Hash{}, db) + reader, err := NewReader(root, common.Hash{}, db) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func NewVerkleTrie(root common.Hash, db database.NodeDatabase, cache *utils.Poin root: verkle.New(), cache: cache, reader: reader, - tracer: newPrevalueTracer(), + tracer: NewPrevalueTracer(), } // Parse the root verkle node if it's not empty. if root != types.EmptyVerkleHash && root != types.EmptyRootHash { @@ -289,7 +289,7 @@ func (t *VerkleTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { nodeset := trienode.NewNodeSet(common.Hash{}) for _, node := range nodes { // Hash parameter is not used in pathdb - nodeset.AddNode(node.Path, trienode.NewNodeWithPrev(common.Hash{}, node.SerializedBytes, t.tracer.get(node.Path))) + nodeset.AddNode(node.Path, trienode.NewNodeWithPrev(common.Hash{}, node.SerializedBytes, t.tracer.Get(node.Path))) } // Serialize root commitment form return t.Hash(), nodeset @@ -322,7 +322,7 @@ func (t *VerkleTrie) Copy() *VerkleTrie { root: t.root.Copy(), cache: t.cache, reader: t.reader, - tracer: t.tracer.copy(), + tracer: t.tracer.Copy(), } } @@ -443,11 +443,11 @@ func (t *VerkleTrie) ToDot() string { } func (t *VerkleTrie) nodeResolver(path []byte) ([]byte, error) { - blob, err := t.reader.node(path, common.Hash{}) + blob, err := t.reader.Node(path, common.Hash{}) if err != nil { return nil, err } - t.tracer.put(path, blob) + t.tracer.Put(path, blob) return blob, nil }