diff --git a/rlp/raw.go b/rlp/raw.go index cec90346a10..022f5ab074f 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -17,6 +17,7 @@ package rlp import ( + "fmt" "io" "reflect" ) @@ -152,6 +153,35 @@ func CountValues(b []byte) (int, error) { return i, nil } +// SplitListValues extracts the raw elements from the list RLP-encoding blob. +func SplitListValues(b []byte) ([][]byte, error) { + b, _, err := SplitList(b) + if err != nil { + return nil, fmt.Errorf("decode error: %v", err) + } + var elements [][]byte + for len(b) > 0 { + _, tagsize, size, err := readKind(b) + if err != nil { + return nil, err + } + elements = append(elements, b[:tagsize+size]) + b = b[tagsize+size:] + } + return elements, nil +} + +// MergeListValues takes a list of raw elements and rlp-encodes them as list. +func MergeListValues(elems [][]byte) ([]byte, error) { + w := NewEncoderBuffer(nil) + offset := w.List() + for _, elem := range elems { + w.Write(elem) + } + w.ListEnd(offset) + return w.ToBytes(), nil +} + func readKind(buf []byte) (k Kind, tagsize, contentsize uint64, err error) { if len(buf) == 0 { return 0, 0, 0, io.ErrUnexpectedEOF diff --git a/trie/node.go b/trie/node.go index 74fac4fd4ea..3f14f07d635 100644 --- a/trie/node.go +++ b/trie/node.go @@ -17,6 +17,7 @@ package trie import ( + "bytes" "fmt" "io" "strings" @@ -242,6 +243,74 @@ func decodeRef(buf []byte) (node, []byte, error) { } } +// decodeNodeElements parses the RLP encoding of a trie node and returns all the +// elements in raw byte format. +// +// For full node, it returns a slice of 17 elements; +// For short node, it returns a slice of 2 elements; +func decodeNodeElements(buf []byte) ([][]byte, error) { + if len(buf) == 0 { + return nil, io.ErrUnexpectedEOF + } + return rlp.SplitListValues(buf) +} + +// encodeNodeElements encodes the provided node elements into a rlp list. +func encodeNodeElements(elements [][]byte) ([]byte, error) { + if len(elements) != 2 && len(elements) != 17 { + return nil, fmt.Errorf("invalid number of elements: %d", len(elements)) + } + return rlp.MergeListValues(elements) +} + +// NodeDifference accepts two RLP-encoding nodes and figures out the difference +// between them. +// +// An error is returned if any of the provided blob is nil, or the type of nodes +// are different. +func NodeDifference(oldvalue []byte, newvalue []byte) (int, []int, [][]byte, error) { + oldElems, err := decodeNodeElements(oldvalue) + if err != nil { + return 0, nil, nil, err + } + newElems, err := decodeNodeElements(newvalue) + if err != nil { + return 0, nil, nil, err + } + if len(oldElems) != len(newElems) { + return 0, nil, nil, fmt.Errorf("different node type, old elements: %d, new elements: %d", len(oldElems), len(newElems)) + } + var ( + indices = make([]int, 0, len(oldElems)) + diff = make([][]byte, 0, len(oldElems)) + ) + for i := 0; i < len(oldElems); i++ { + if !bytes.Equal(oldElems[i], newElems[i]) { + indices = append(indices, i) + diff = append(diff, oldElems[i]) + } + } + return len(oldElems), indices, diff, nil +} + +// ReassembleNode accepts a RLP-encoding node along with a set of mutations, +// applying the modification diffs according to the indices and re-assemble. +func ReassembleNode(blob []byte, mutations [][][]byte, indices [][]int) ([]byte, error) { + if len(mutations) == 0 && len(indices) == 0 { + return blob, nil + } + elements, err := decodeNodeElements(blob) + if err != nil { + return nil, err + } + for i := 0; i < len(mutations); i++ { + for j, pos := range indices[i] { + elements[pos] = mutations[i][j] + } + } + return encodeNodeElements(elements) +} + // wraps a decoding error with information about the path to the // invalid child node (for debugging encoding issues). type decodeError struct { diff --git a/trie/node_test.go b/trie/node_test.go index 9b8b33748fa..875f6e38dc7 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -18,9 +18,12 @@ package trie import ( "bytes" + "math/rand" + "reflect" "testing" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testrand" "github.com/ethereum/go-ethereum/rlp" ) @@ -94,6 +97,286 @@ func TestDecodeFullNode(t *testing.T) { } } +func makeTestLeafNode(small bool) []byte { + l := leafNodeEncoder{} + l.Key = hexToCompact(keybytesToHex(testrand.Bytes(10))) + if small { + l.Val = testrand.Bytes(10) + } else { + l.Val = testrand.Bytes(32) + } + buf := rlp.NewEncoderBuffer(nil) + l.encode(buf) + return buf.ToBytes() +} + +func makeTestFullNode(small bool) []byte { + n := fullnodeEncoder{} + for i := 0; i < 16; i++ { + switch rand.Intn(3) { + case 0: + // write nil + case 1: + // write hash + n.Children[i] = testrand.Bytes(32) + case 2: + // write embedded node + n.Children[i] = makeTestLeafNode(small) + } + } + n.Children[16] = testrand.Bytes(32) // value + buf := rlp.NewEncoderBuffer(nil) + n.encode(buf) + return buf.ToBytes() +} + +func TestEncodeDecodeNodeElements(t *testing.T) { + var nodes [][]byte + nodes = append(nodes, makeTestFullNode(true)) + nodes = append(nodes, makeTestFullNode(false)) + nodes = append(nodes, makeTestLeafNode(true)) + nodes = append(nodes, makeTestLeafNode(false)) + + for _, blob := range nodes { + elements, err := decodeNodeElements(blob) + if err != nil { + t.Fatalf("Failed to decode node elements: %v", err) + } + enc, err := encodeNodeElements(elements) + if err != nil { + t.Fatalf("Failed to encode node elements: %v", err) + } + if !bytes.Equal(enc, blob) { + t.Fatalf("Unexpected encoded node element, want: %v, got: %v", blob, enc) + } + } +} + +func makeTestLeafNodePair() ([]byte, []byte, [][]byte, []int) { + var ( + na = leafNodeEncoder{} + nb = leafNodeEncoder{} + ) + key := keybytesToHex(testrand.Bytes(10)) + na.Key = hexToCompact(key) + nb.Key = hexToCompact(key) + + valA := testrand.Bytes(32) + valB := testrand.Bytes(32) + na.Val = valA + nb.Val = valB + + bufa, bufb := rlp.NewEncoderBuffer(nil), rlp.NewEncoderBuffer(nil) + na.encode(bufa) + nb.encode(bufb) + diff, _ := rlp.EncodeToBytes(valA) + return bufa.ToBytes(), bufb.ToBytes(), [][]byte{diff}, []int{1} +} + +func makeTestFullNodePair() ([]byte, []byte, [][]byte, []int) { + var ( + na = fullnodeEncoder{} + nb = fullnodeEncoder{} + indices []int + values [][]byte + ) + for i := 0; i < 16; i++ { + switch rand.Intn(3) { + case 0: + // write nil + case 1: + // write same + var child []byte + if rand.Intn(2) == 0 { + child = testrand.Bytes(32) // hashnode + } else { + child = makeTestLeafNode(true) // embedded node + } + na.Children[i] = child + nb.Children[i] = child + case 2: + // write different + var ( + va []byte + diff []byte + ) + rnd := rand.Intn(3) + if rnd == 0 { + va = testrand.Bytes(32) // hashnode + diff, _ = rlp.EncodeToBytes(va) + } else if rnd == 1 { + va = makeTestLeafNode(true) // embedded node + diff = va + } else { + va = nil + diff = rlp.EmptyString + } + vb := testrand.Bytes(32) // hashnode + na.Children[i] = va + nb.Children[i] = vb + + indices = append(indices, i) + values = append(values, diff) + } + } + na.Children[16] = nil + nb.Children[16] = nil + + bufa, bufb := rlp.NewEncoderBuffer(nil), rlp.NewEncoderBuffer(nil) + na.encode(bufa) + nb.encode(bufb) + return bufa.ToBytes(), bufb.ToBytes(), values, indices +} + +func TestNodeDifference(t *testing.T) { + type testsuite struct { + old []byte + new []byte + expErr bool + expIndices []int + expValues [][]byte + } + var tests = []testsuite{ + // Invalid node data + { + old: nil, new: nil, expErr: true, + }, + { + old: testrand.Bytes(32), new: nil, expErr: true, + }, + { + old: nil, new: testrand.Bytes(32), expErr: true, + }, + { + old: testrand.Bytes(32), new: testrand.Bytes(32), expErr: true, + }, + // Different node type + { + old: makeTestLeafNode(true), new: makeTestFullNode(true), expErr: true, + }, + } + for range 10 { + va, vb, elements, indices := makeTestLeafNodePair() + tests = append(tests, testsuite{ + old: va, + new: vb, + expErr: false, + expIndices: indices, + expValues: elements, + }) + } + for range 10 { + va, vb, elements, indices := makeTestFullNodePair() + tests = append(tests, testsuite{ + old: va, + new: vb, + expErr: false, + expIndices: indices, + expValues: elements, + }) + } + + for _, test := range tests { + _, indices, values, err := NodeDifference(test.old, test.new) + if test.expErr && err == nil { + t.Fatal("Expect error, got nil") + } + if !test.expErr && err != nil { + t.Fatalf("Unexpect error, %v", err) + } + if err == nil { + if !reflect.DeepEqual(indices, test.expIndices) { + t.Fatalf("Unexpected indices, want: %v, got: %v", test.expIndices, indices) + } + if !reflect.DeepEqual(values, test.expValues) { + t.Fatalf("Unexpected values, want: %v, got: %v", test.expValues, values) + } + } + } +} + +func TestReassembleFullNode(t *testing.T) { + var fn fullnodeEncoder + for i := 0; i < 16; i++ { + if rand.Intn(2) == 0 { + fn.Children[i] = testrand.Bytes(32) + } + } + buf := rlp.NewEncoderBuffer(nil) + fn.encode(buf) + enc := buf.ToBytes() + + // Generate a list of diffs + var ( + values [][][]byte + indices [][]int + ) + for i := 0; i < 10; i++ { + var ( + pos = make(map[int]struct{}) + poslist []int + valuelist [][]byte + ) + for j := 0; j < 3; j++ { + p := rand.Intn(16) + if _, ok := pos[p]; ok { + continue + } + pos[p] = struct{}{} + + nh := testrand.Bytes(32) + diff, _ := rlp.EncodeToBytes(nh) + poslist = append(poslist, p) + valuelist = append(valuelist, diff) + fn.Children[p] = nh + } + values = append(values, valuelist) + indices = append(indices, poslist) + } + reassembled, err := ReassembleNode(enc, values, indices) + if err != nil { + t.Fatalf("Failed to re-assemble full node %v", err) + } + buf2 := rlp.NewEncoderBuffer(nil) + fn.encode(buf2) + enc2 := buf2.ToBytes() + if !reflect.DeepEqual(enc2, reassembled) { + t.Fatalf("Unexpeted reassembled node") + } +} + +func TestReassembleShortNode(t *testing.T) { + var ln leafNodeEncoder + ln.Key = hexToCompact(keybytesToHex(testrand.Bytes(10))) + ln.Val = testrand.Bytes(10) + buf := rlp.NewEncoderBuffer(nil) + ln.encode(buf) + enc := buf.ToBytes() + + // Generate a list of diffs + var ( + values [][][]byte + indices [][]int + ) + for i := 0; i < 10; i++ { + val := testrand.Bytes(10) + ln.Val = val + diff, _ := rlp.EncodeToBytes(val) + values = append(values, [][]byte{diff}) + indices = append(indices, []int{1}) + } + reassembled, err := ReassembleNode(enc, values, indices) + if err != nil { + t.Fatalf("Failed to re-assemble full node %v", err) + } + buf2 := rlp.NewEncoderBuffer(nil) + ln.encode(buf2) + enc2 := buf2.ToBytes() + if !reflect.DeepEqual(enc2, reassembled) { + t.Fatalf("Unexpeted reassembled node") + } +} + // goos: darwin // goarch: arm64 // pkg: github.com/ethereum/go-ethereum/trie diff --git a/triedb/pathdb/nodes.go b/triedb/pathdb/nodes.go index c6f9e7aece3..b2720c36dfd 100644 --- a/triedb/pathdb/nodes.go +++ b/triedb/pathdb/nodes.go @@ -14,12 +14,14 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . +// nolint:unused package pathdb import ( "bytes" "errors" "fmt" + "hash/fnv" "io" "maps" @@ -30,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" ) @@ -424,3 +427,237 @@ func (s *nodeSetWithOrigin) decode(r *rlp.Stream) error { s.computeSize() return nil } + +// encodeNodeCompressed encodes the trie node differences between two consecutive +// versions into byte stream. The format is as below: +// +// - metadata byte layout (1 byte): +// +// ┌──── Bits (from MSB to LSB) ───┐ +// │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ +// └───────────────────────────────┘ +// │ │ │ │ │ │ │ └─ FlagA: marks if value is encoded in compressed format (1 always) +// │ │ │ │ │ │ └───── FlagB: marks if no extended bitmap used after the metadata byte +// │ │ │ │ │ └───────── FlagC: bitmap for node (only used when flagB == 1) +// │ │ │ │ └───────────── FlagD: bitmap for node (only used when flagB == 1) +// │ │ │ └───────────────── FlagE: reserved +// │ │ └───────────────────── FlagF: reserved +// │ └───────────────────────── FlagG: reserved +// └───────────────────────────── FlagH: reserved +// +// Example: +// +// 0b_0000_1011 +// +// Bit0=1, Bit1=1 -> node in compressed format, no extended bitmap +// Bit2=0, Bit3=1 -> the key of a short node is not stored; its value is stored. +// +// - 2 bytes extended bitmap (full node only), each bit represents a corresponding child; +// - concatenation of original value of modified children along with its size +func encodeNodeCompressed(addExtension bool, elements [][]byte, indices []int) []byte { + var ( + enc []byte + flag = byte(1) + ) + if !addExtension { + flag |= 2 + + // Embedded bitmap + for _, pos := range indices { + flag |= 1 << (pos + 2) + } + enc = append(enc, flag) + } else { + // Extended bitmap + bitmap := make([]byte, 2) // bitmaps for at most 16 children + for _, pos := range indices { + // Children[16] is only theoretically possible in the Merkle-Patricia-trie, + // in practice this field is never used in the Ethereum case. + if pos == 16 { + panic(fmt.Sprintf("Unexpected node children index %d", pos)) + } + bitIndex := uint(pos % 8) + bitmap[pos/8] |= 1 << bitIndex + } + enc = append(enc, flag) + enc = append(enc, bitmap...) + } + for _, element := range elements { + enc = append(enc, byte(len(element))) // 1 byte is sufficient for element size + enc = append(enc, element...) + } + return enc +} + +// encodeNodeFull encodes the full trie node value into byte stream. The format is +// as below: +// +// - metadata byte layout (1 byte): 0b0 +// - node value +func encodeNodeFull(value []byte) []byte { + enc := make([]byte, len(value)+1) + copy(enc[1:], value) + return enc +} + +// decodeNodeCompressed decodes the byte stream of compressed trie node +// back to the original elements and their indices. +// +// It assumes the byte stream contains a compressed format node. +func decodeNodeCompressed(data []byte) ([][]byte, []int, error) { + if len(data) < 1 { + return nil, nil, errors.New("invalid data: too short") + } + flag := data[0] + if flag&byte(1) == 0 { + return nil, nil, errors.New("invalid data: full node value") + } + noExtend := flag&byte(2) != 0 + + // Reconstruct indices from bitmap + var indices []int + if noExtend { + if flag&byte(4) != 0 { + indices = append(indices, 0) + } + if flag&byte(8) != 0 { + indices = append(indices, 1) + } + data = data[1:] + } else { + if len(data) < 3 { + return nil, nil, errors.New("invalid data: too short") + } + bitmap := data[1:3] + for index, b := range bitmap { + for bitIdx := 0; bitIdx < 8; bitIdx++ { + if b&(1<