Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion adapter/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ func (i *Internal) Forward(_ context.Context, req *pb.ForwardRequest) (*pb.Forwa

r, err := i.transactionManager.Commit(req.Requests)
if err != nil {
return nil, errors.WithStack(err)
return &pb.ForwardResponse{
Success: false,
CommitIndex: 0,
}, errors.WithStack(err)
}

return &pb.ForwardResponse{
Expand Down
67 changes: 57 additions & 10 deletions adapter/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
Expand Down Expand Up @@ -55,11 +56,13 @@ type portsAdress struct {

const (
// raft and the grpc requested by the client use grpc and are received on the same port
grpcPort = 50000
raftPort = 50000

grpcPort = 50000
raftPort = 50000
redisPort = 63790
dynamoPort = 28000

// followers wait longer before starting elections to give the leader time to bootstrap and share config.
followerElectionTimeout = 10 * time.Second
)

var mu sync.Mutex
Expand Down Expand Up @@ -136,6 +139,7 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) {
nodes, grpcAdders, redisAdders := setupNodes(t, ctx, n, ports, cfg)

waitForNodeListeners(t, ctx, nodes, waitTimeout, waitInterval)
waitForConfigReplication(t, cfg, nodes, waitTimeout, waitInterval)
waitForRaftReadiness(t, nodes, waitTimeout, waitInterval)

return nodes, grpcAdders, redisAdders
Expand Down Expand Up @@ -163,9 +167,6 @@ func waitForNodeListeners(t *testing.T, ctx context.Context, nodes []Node, waitT

func waitForRaftReadiness(t *testing.T, nodes []Node, waitTimeout, waitInterval time.Duration) {
t.Helper()
assert.Eventually(t, func() bool {
return nodes[0].raft.State() == raft.Leader
}, waitTimeout, waitInterval)

expectedLeader := raft.ServerAddress(nodes[0].raftAddress)
assert.Eventually(t, func() bool {
Expand All @@ -188,6 +189,40 @@ func waitForRaftReadiness(t *testing.T, nodes []Node, waitTimeout, waitInterval
}, waitTimeout, waitInterval)
}

// waitForConfigReplication blocks until every node reports a raft membership
// that exactly matches cfg — same server count, and every expected server
// present with matching ID, address, and suffrage — or until waitTimeout
// elapses (checked every waitInterval via assert.Eventually).
func waitForConfigReplication(t *testing.T, cfg raft.Configuration, nodes []Node, waitTimeout, waitInterval time.Duration) {
	t.Helper()

	// nodeInSync reports whether a single node's current configuration
	// matches the expected cluster configuration.
	nodeInSync := func(n Node) bool {
		future := n.raft.GetConfiguration()
		if err := future.Error(); err != nil {
			return false
		}
		got := future.Configuration().Servers
		if len(got) != len(cfg.Servers) {
			return false
		}
		for _, want := range cfg.Servers {
			if !containsServer(got, want) {
				return false
			}
		}
		return true
	}

	assert.Eventually(t, func() bool {
		for _, n := range nodes {
			if !nodeInSync(n) {
				return false
			}
		}
		return true
	}, waitTimeout, waitInterval)
}

// containsServer reports whether servers holds an entry whose ID, Address,
// and Suffrage all equal those of expected.
func containsServer(servers []raft.Server, expected raft.Server) bool {
	for i := range servers {
		match := servers[i].ID == expected.ID &&
			servers[i].Address == expected.Address &&
			servers[i].Suffrage == expected.Suffrage
		if match {
			return true
		}
	}
	return false
}

func assignPorts(n int) []portsAdress {
ports := make([]portsAdress, n)
for i := 0; i < n; i++ {
Expand All @@ -214,6 +249,8 @@ func buildRaftConfig(n int, ports []portsAdress) raft.Configuration {
return cfg
}

const leaderElectionTimeout = 0 * time.Second

func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, cfg raft.Configuration) ([]Node, []string, []string) {
t.Helper()
var grpcAdders []string
Expand All @@ -228,7 +265,13 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c

port := ports[i]

r, tm, err := newRaft(strconv.Itoa(i), port.raftAddress, fsm, i == 0, cfg)
// Let the leader start the election first
electionTimeout := leaderElectionTimeout
if i != 0 {
electionTimeout = followerElectionTimeout
}

r, tm, err := newRaft(strconv.Itoa(i), port.raftAddress, fsm, i == 0, cfg, electionTimeout)
assert.NoError(t, err)

s := grpc.NewServer()
Expand All @@ -244,7 +287,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c
raftadmin.Register(s, r)

grpcSock, err := lc.Listen(ctx, "tcp", port.grpcAddress)
assert.NoError(t, err)
require.NoError(t, err)

grpcAdders = append(grpcAdders, port.grpcAddress)
redisAdders = append(redisAdders, port.redisAddress)
Expand All @@ -253,7 +296,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c
}(s, grpcSock)

l, err := lc.Listen(ctx, "tcp", port.redisAddress)
assert.NoError(t, err)
require.NoError(t, err)
rd := NewRedisServer(l, st, coordinator)
go func(server *RedisServer) {
assert.NoError(t, server.Run())
Expand Down Expand Up @@ -282,10 +325,14 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c
return nodes, grpcAdders, redisAdders
}

func newRaft(myID string, myAddress string, fsm raft.FSM, bootstrap bool, cfg raft.Configuration) (*raft.Raft, *transport.Manager, error) {
func newRaft(myID string, myAddress string, fsm raft.FSM, bootstrap bool, cfg raft.Configuration, electionTimeout time.Duration) (*raft.Raft, *transport.Manager, error) {
c := raft.DefaultConfig()
c.LocalID = raft.ServerID(myID)

if electionTimeout > 0 {
c.ElectionTimeout = electionTimeout
}

// this config is for development
ldb := raft.NewInmemStore()
sdb := raft.NewInmemStore()
Expand Down
4 changes: 4 additions & 0 deletions kv/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ func (c *Coordinate) redirect(reqs *OperationGroup[OP]) (*CoordinateResponse, er
return nil, errors.WithStack(err)
}

if !r.Success {
return nil, ErrInvalidRequest
}

return &CoordinateResponse{
CommitIndex: r.CommitIndex,
}, nil
Expand Down
55 changes: 40 additions & 15 deletions kv/shard_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,58 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
t.Helper()

const n = 3
addrs, trans := setupInmemTransports(id, n)
connectInmemTransports(addrs, trans)
cfg := buildRaftConfig(id, addrs)
rafts := initTestRafts(t, cfg, trans, fsm)
waitForLeader(t, id, rafts[0])

shutdown := func() {
for _, r := range rafts {
r.Shutdown()
}
}
return rafts[0], shutdown
}

// setupInmemTransports creates n in-memory raft transports whose server
// addresses are derived from id ("<id>-0" … "<id>-<n-1>") and returns the
// addresses and transports as index-aligned slices.
func setupInmemTransports(id string, n int) ([]raft.ServerAddress, []*raft.InmemTransport) {
	addrs := make([]raft.ServerAddress, 0, n)
	trans := make([]*raft.InmemTransport, 0, n)
	for i := 0; i < n; i++ {
		name := raft.ServerAddress(fmt.Sprintf("%s-%d", id, i))
		addr, tr := raft.NewInmemTransport(name)
		addrs = append(addrs, addr)
		trans = append(trans, tr)
	}
	return addrs, trans
}

func connectInmemTransports(addrs []raft.ServerAddress, trans []*raft.InmemTransport) {
// fully connect transports
for i := 0; i < n; i++ {
for j := i + 1; j < n; j++ {
for i := 0; i < len(trans); i++ {
for j := i + 1; j < len(trans); j++ {
trans[i].Connect(addrs[j], trans[j])
trans[j].Connect(addrs[i], trans[i])
}
}
}

func buildRaftConfig(id string, addrs []raft.ServerAddress) raft.Configuration {
// cluster configuration
cfg := raft.Configuration{}
for i := 0; i < n; i++ {
for i := 0; i < len(addrs); i++ {
cfg.Servers = append(cfg.Servers, raft.Server{
ID: raft.ServerID(fmt.Sprintf("%s-%d", id, i)),
Address: addrs[i],
})
}
return cfg
}

rafts := make([]*raft.Raft, n)
for i := 0; i < n; i++ {
func initTestRafts(t *testing.T, cfg raft.Configuration, trans []*raft.InmemTransport, fsm raft.FSM) []*raft.Raft {
t.Helper()

rafts := make([]*raft.Raft, len(trans))
for i := 0; i < len(trans); i++ {
c := raft.DefaultConfig()
c.LocalID = cfg.Servers[i].ID
if i == 0 {
Expand Down Expand Up @@ -73,23 +99,22 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
rafts[i] = r
}

// node 0 should become leader
return rafts
}

func waitForLeader(t *testing.T, id string, leader *raft.Raft) {
t.Helper()

// node 0 should become leader quickly during tests
for i := 0; i < 100; i++ {
if rafts[0].State() == raft.Leader {
if leader.State() == raft.Leader {
break
}
time.Sleep(50 * time.Millisecond)
}
if rafts[0].State() != raft.Leader {
if leader.State() != raft.Leader {
t.Fatalf("node %s-0 is not leader", id)
}

shutdown := func() {
for _, r := range rafts {
r.Shutdown()
}
}
return rafts[0], shutdown
}

func TestShardRouterCommit(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions store/memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type memoryStore struct {

const (
defaultExpireInterval = 30 * time.Second
checksumSize = 4
)

func NewMemoryStore() Store {
Expand Down Expand Up @@ -230,11 +231,11 @@ func (s *memoryStore) Restore(r io.Reader) error {
if err != nil {
return errors.WithStack(err)
}
if len(data) < 4 {
if len(data) < checksumSize {
return errors.WithStack(ErrInvalidChecksum)
}
payload := data[:len(data)-4]
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
payload := data[:len(data)-checksumSize]
expected := binary.LittleEndian.Uint32(data[len(data)-checksumSize:])
if crc32.ChecksumIEEE(payload) != expected {
return errors.WithStack(ErrInvalidChecksum)
}
Expand Down
15 changes: 11 additions & 4 deletions store/memory_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package store
import (
"bytes"
"context"
"io"
"strconv"
"sync"
"testing"
Expand Down Expand Up @@ -217,8 +218,11 @@ func TestMemoryStore_SnapshotChecksum(t *testing.T) {
buf, err := st.Snapshot()
assert.NoError(t, err)

snapshotData, err := io.ReadAll(buf)
assert.NoError(t, err)

st2 := NewMemoryStore()
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
err = st2.Restore(bytes.NewReader(snapshotData))
assert.NoError(t, err)

v, err := st2.Get(ctx, []byte("foo"))
Expand All @@ -234,9 +238,12 @@ func TestMemoryStore_SnapshotChecksum(t *testing.T) {

buf, err := st.Snapshot()
assert.NoError(t, err)
data := buf.(*bytes.Buffer).Bytes()
corrupted := make([]byte, len(data))
copy(corrupted, data)

snapshotData, err := io.ReadAll(buf)
assert.NoError(t, err)

corrupted := make([]byte, len(snapshotData))
copy(corrupted, snapshotData)
corrupted[0] ^= 0xff

st2 := NewMemoryStore()
Expand Down
6 changes: 3 additions & 3 deletions store/rb_memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,11 @@ func (s *rbMemoryStore) Restore(r io.Reader) error {
if err != nil {
return errors.WithStack(err)
}
if len(data) < 4 {
if len(data) < checksumSize {
return errors.WithStack(ErrInvalidChecksum)
}
payload := data[:len(data)-4]
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
payload := data[:len(data)-checksumSize]
expected := binary.LittleEndian.Uint32(data[len(data)-checksumSize:])
if crc32.ChecksumIEEE(payload) != expected {
return errors.WithStack(ErrInvalidChecksum)
}
Expand Down
15 changes: 11 additions & 4 deletions store/rb_memory_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/binary"
"io"
"strconv"
"sync"
"testing"
Expand Down Expand Up @@ -218,8 +219,11 @@ func TestRbMemoryStore_SnapshotChecksum(t *testing.T) {
buf, err := st.Snapshot()
assert.NoError(t, err)

snapshotData, err := io.ReadAll(buf)
assert.NoError(t, err)

st2 := NewRbMemoryStore()
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
err = st2.Restore(bytes.NewReader(snapshotData))
assert.NoError(t, err)

v, err := st2.Get(ctx, []byte("foo"))
Expand All @@ -235,9 +239,12 @@ func TestRbMemoryStore_SnapshotChecksum(t *testing.T) {

buf, err := st.Snapshot()
assert.NoError(t, err)
data := buf.(*bytes.Buffer).Bytes()
corrupted := make([]byte, len(data))
copy(corrupted, data)

snapshotData, err := io.ReadAll(buf)
assert.NoError(t, err)

corrupted := make([]byte, len(snapshotData))
copy(corrupted, snapshotData)
corrupted[0] ^= 0xff

st2 := NewRbMemoryStore()
Expand Down
Loading