Skip to content

Commit cb1d5db

Browse files
committed
Refactor snapshot checksum handling logic
1 parent 49860b6 commit cb1d5db

File tree

6 files changed

+73
-33
lines changed

6 files changed

+73
-33
lines changed

adapter/test_util.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ type portsAdress struct {
5656

5757
const (
5858
// raft and the grpc requested by the client use grpc and are received on the same port
59-
grpcPort = 50000
60-
raftPort = 50000
61-
redisPort = 63790
59+
grpcPort = 50000
60+
raftPort = 50000
61+
redisPort = 63790
6262
dynamoPort = 28000
63-
63+
6464
// followers wait longer before starting elections to give the leader time to bootstrap and share config.
6565
followerElectionTimeout = 10 * time.Second
6666
)

kv/shard_router_test.go

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,58 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
1717
t.Helper()
1818

1919
const n = 3
20+
addrs, trans := setupInmemTransports(id, n)
21+
connectInmemTransports(addrs, trans)
22+
cfg := buildRaftConfig(id, addrs)
23+
rafts := initTestRafts(t, cfg, trans, fsm)
24+
waitForLeader(t, id, rafts[0])
25+
26+
shutdown := func() {
27+
for _, r := range rafts {
28+
r.Shutdown()
29+
}
30+
}
31+
return rafts[0], shutdown
32+
}
33+
34+
func setupInmemTransports(id string, n int) ([]raft.ServerAddress, []*raft.InmemTransport) {
2035
addrs := make([]raft.ServerAddress, n)
2136
trans := make([]*raft.InmemTransport, n)
2237
for i := 0; i < n; i++ {
2338
addr, tr := raft.NewInmemTransport(raft.ServerAddress(fmt.Sprintf("%s-%d", id, i)))
2439
addrs[i] = addr
2540
trans[i] = tr
2641
}
42+
return addrs, trans
43+
}
44+
45+
func connectInmemTransports(addrs []raft.ServerAddress, trans []*raft.InmemTransport) {
2746
// fully connect transports
28-
for i := 0; i < n; i++ {
29-
for j := i + 1; j < n; j++ {
47+
for i := 0; i < len(trans); i++ {
48+
for j := i + 1; j < len(trans); j++ {
3049
trans[i].Connect(addrs[j], trans[j])
3150
trans[j].Connect(addrs[i], trans[i])
3251
}
3352
}
53+
}
3454

55+
func buildRaftConfig(id string, addrs []raft.ServerAddress) raft.Configuration {
3556
// cluster configuration
3657
cfg := raft.Configuration{}
37-
for i := 0; i < n; i++ {
58+
for i := 0; i < len(addrs); i++ {
3859
cfg.Servers = append(cfg.Servers, raft.Server{
3960
ID: raft.ServerID(fmt.Sprintf("%s-%d", id, i)),
4061
Address: addrs[i],
4162
})
4263
}
64+
return cfg
65+
}
4366

44-
rafts := make([]*raft.Raft, n)
45-
for i := 0; i < n; i++ {
67+
func initTestRafts(t *testing.T, cfg raft.Configuration, trans []*raft.InmemTransport, fsm raft.FSM) []*raft.Raft {
68+
t.Helper()
69+
70+
rafts := make([]*raft.Raft, len(trans))
71+
for i := 0; i < len(trans); i++ {
4672
c := raft.DefaultConfig()
4773
c.LocalID = cfg.Servers[i].ID
4874
if i == 0 {
@@ -73,23 +99,22 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
7399
rafts[i] = r
74100
}
75101

76-
// node 0 should become leader
102+
return rafts
103+
}
104+
105+
func waitForLeader(t *testing.T, id string, leader *raft.Raft) {
106+
t.Helper()
107+
108+
// node 0 should become leader quickly during tests
77109
for i := 0; i < 100; i++ {
78-
if rafts[0].State() == raft.Leader {
110+
if leader.State() == raft.Leader {
79111
break
80112
}
81113
time.Sleep(50 * time.Millisecond)
82114
}
83-
if rafts[0].State() != raft.Leader {
115+
if leader.State() != raft.Leader {
84116
t.Fatalf("node %s-0 is not leader", id)
85117
}
86-
87-
shutdown := func() {
88-
for _, r := range rafts {
89-
r.Shutdown()
90-
}
91-
}
92-
return rafts[0], shutdown
93118
}
94119

95120
func TestShardRouterCommit(t *testing.T) {

store/memory_store.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type memoryStore struct {
2929

3030
const (
3131
defaultExpireInterval = 30 * time.Second
32+
checksumSize = 4
3233
)
3334

3435
func NewMemoryStore() Store {
@@ -230,11 +231,11 @@ func (s *memoryStore) Restore(r io.Reader) error {
230231
if err != nil {
231232
return errors.WithStack(err)
232233
}
233-
if len(data) < 4 {
234+
if len(data) < checksumSize {
234235
return errors.WithStack(ErrInvalidChecksum)
235236
}
236-
payload := data[:len(data)-4]
237-
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
237+
payload := data[:len(data)-checksumSize]
238+
expected := binary.LittleEndian.Uint32(data[len(data)-checksumSize:])
238239
if crc32.ChecksumIEEE(payload) != expected {
239240
return errors.WithStack(ErrInvalidChecksum)
240241
}

store/memory_store_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package store
33
import (
44
"bytes"
55
"context"
6+
"io"
67
"strconv"
78
"sync"
89
"testing"
@@ -217,8 +218,11 @@ func TestMemoryStore_SnapshotChecksum(t *testing.T) {
217218
buf, err := st.Snapshot()
218219
assert.NoError(t, err)
219220

221+
snapshotData, err := io.ReadAll(buf)
222+
assert.NoError(t, err)
223+
220224
st2 := NewMemoryStore()
221-
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
225+
err = st2.Restore(bytes.NewReader(snapshotData))
222226
assert.NoError(t, err)
223227

224228
v, err := st2.Get(ctx, []byte("foo"))
@@ -234,9 +238,12 @@ func TestMemoryStore_SnapshotChecksum(t *testing.T) {
234238

235239
buf, err := st.Snapshot()
236240
assert.NoError(t, err)
237-
data := buf.(*bytes.Buffer).Bytes()
238-
corrupted := make([]byte, len(data))
239-
copy(corrupted, data)
241+
242+
snapshotData, err := io.ReadAll(buf)
243+
assert.NoError(t, err)
244+
245+
corrupted := make([]byte, len(snapshotData))
246+
copy(corrupted, snapshotData)
240247
corrupted[0] ^= 0xff
241248

242249
st2 := NewMemoryStore()

store/rb_memory_store.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,11 @@ func (s *rbMemoryStore) Restore(r io.Reader) error {
250250
if err != nil {
251251
return errors.WithStack(err)
252252
}
253-
if len(data) < 4 {
253+
if len(data) < checksumSize {
254254
return errors.WithStack(ErrInvalidChecksum)
255255
}
256-
payload := data[:len(data)-4]
257-
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
256+
payload := data[:len(data)-checksumSize]
257+
expected := binary.LittleEndian.Uint32(data[len(data)-checksumSize:])
258258
if crc32.ChecksumIEEE(payload) != expected {
259259
return errors.WithStack(ErrInvalidChecksum)
260260
}

store/rb_memory_store_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/binary"
7+
"io"
78
"strconv"
89
"sync"
910
"testing"
@@ -218,8 +219,11 @@ func TestRbMemoryStore_SnapshotChecksum(t *testing.T) {
218219
buf, err := st.Snapshot()
219220
assert.NoError(t, err)
220221

222+
snapshotData, err := io.ReadAll(buf)
223+
assert.NoError(t, err)
224+
221225
st2 := NewRbMemoryStore()
222-
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
226+
err = st2.Restore(bytes.NewReader(snapshotData))
223227
assert.NoError(t, err)
224228

225229
v, err := st2.Get(ctx, []byte("foo"))
@@ -235,9 +239,12 @@ func TestRbMemoryStore_SnapshotChecksum(t *testing.T) {
235239

236240
buf, err := st.Snapshot()
237241
assert.NoError(t, err)
238-
data := buf.(*bytes.Buffer).Bytes()
239-
corrupted := make([]byte, len(data))
240-
copy(corrupted, data)
242+
243+
snapshotData, err := io.ReadAll(buf)
244+
assert.NoError(t, err)
245+
246+
corrupted := make([]byte, len(snapshotData))
247+
copy(corrupted, snapshotData)
241248
corrupted[0] ^= 0xff
242249

243250
st2 := NewRbMemoryStore()

0 commit comments

Comments
 (0)