Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions kv/shard_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package kv

import (
	"context"
	"sort"
	"sync"

	"github.com/bootjp/elastickv/distribution"
	pb "github.com/bootjp/elastickv/proto"
	"github.com/bootjp/elastickv/store"
	"github.com/cockroachdb/errors"
)

// ShardRouter routes requests to multiple raft groups based on key ranges.
// It does not provide transactional guarantees across shards; commits are executed
// per shard and failures may leave partial results.
type ShardRouter struct {
	// engine maps keys to routes (group IDs); consulted on every request.
	// NOTE(review): callers replace engine wholesale without holding mu
	// (see the split/merge test) — confirm that is safe under concurrency.
	engine *distribution.Engine
	// mu guards groups.
	mu sync.RWMutex
	// groups maps a raft group ID to its transaction manager and store.
	groups map[uint64]*routerGroup
}

// routerGroup bundles the per-group handles the router needs: a
// transaction manager for writes and a store for local reads.
type routerGroup struct {
	tm    Transactional
	store store.Store
}

// NewShardRouter creates a router backed by the given distribution engine.
// Groups start empty and are added via Register.
func NewShardRouter(e *distribution.Engine) *ShardRouter {
	r := &ShardRouter{engine: e}
	r.groups = map[uint64]*routerGroup{}
	return r
}

// Register associates a raft group ID with its transactional manager and store.
func (s *ShardRouter) Register(group uint64, tm Transactional, st store.Store) {
	entry := &routerGroup{tm: tm, store: st}

	s.mu.Lock()
	s.groups[group] = entry
	s.mu.Unlock()
}

// Commit dispatches commits to the correct raft group.
// There is no cross-shard atomicity: if one group's commit fails after
// another's succeeded, the successful writes remain in place.
func (s *ShardRouter) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
	return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) {
		return g.tm.Commit(rs)
	})
}

// Abort dispatches aborts to the correct raft group.
func (s *ShardRouter) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
	abortFn := func(g *routerGroup, batch []*pb.Request) (*TransactionResponse, error) {
		return g.tm.Abort(batch)
	}
	return s.process(reqs, abortFn)
}

// process groups reqs by raft group and invokes fn once per group,
// returning the highest commit index observed. Groups are visited in
// ascending ID order so that which shards run before a failure is
// deterministic; Go map iteration order would otherwise vary run to run.
// On the first error, processing stops — groups already handled keep
// their results (no cross-shard rollback).
func (s *ShardRouter) process(reqs []*pb.Request, fn func(*routerGroup, []*pb.Request) (*TransactionResponse, error)) (*TransactionResponse, error) {
	grouped, err := s.groupRequests(reqs)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	gids := make([]uint64, 0, len(grouped))
	for gid := range grouped {
		gids = append(gids, gid)
	}
	sort.Slice(gids, func(i, j int) bool { return gids[i] < gids[j] })

	// maxIndex (not "max") avoids shadowing the predeclared builtin.
	var maxIndex uint64
	for _, gid := range gids {
		g, ok := s.getGroup(gid)
		if !ok {
			return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
		}
		r, err := fn(g, grouped[gid])
		if err != nil {
			return nil, errors.WithStack(err)
		}
		if r.CommitIndex > maxIndex {
			maxIndex = r.CommitIndex
		}
	}
	return &TransactionResponse{CommitIndex: maxIndex}, nil
}

// getGroup looks up the routerGroup registered for id under the read lock.
func (s *ShardRouter) getGroup(id uint64) (*routerGroup, bool) {
	s.mu.RLock()
	g, ok := s.groups[id]
	s.mu.RUnlock()
	return g, ok
}

// groupRequests partitions reqs by the raft group that owns their keys.
// Every mutation in a request must route to the same group; previously a
// request was routed by its first mutation's key only, so mutations that
// crossed a shard boundary were silently applied to a shard that does not
// own them. Such requests are now rejected with ErrInvalidRequest.
func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
	batches := make(map[uint64][]*pb.Request)
	for _, r := range reqs {
		if len(r.Mutations) == 0 {
			return nil, errors.WithStack(ErrInvalidRequest)
		}
		key := r.Mutations[0].Key
		route, ok := s.engine.GetRoute(key)
		if !ok {
			return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
		}
		// Reject requests whose mutations span shards; per-request
		// dispatch can only target a single raft group.
		for _, m := range r.Mutations[1:] {
			mr, ok := s.engine.GetRoute(m.Key)
			if !ok {
				return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", m.Key)
			}
			if mr.GroupID != route.GroupID {
				return nil, errors.Wrapf(ErrInvalidRequest, "request spans groups %d and %d", route.GroupID, mr.GroupID)
			}
		}
		batches[route.GroupID] = append(batches[route.GroupID], r)
	}
	return batches, nil
}

// Get retrieves a key routed to the correct shard.
func (s *ShardRouter) Get(ctx context.Context, key []byte) ([]byte, error) {
	route, ok := s.engine.GetRoute(key)
	if !ok {
		return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
	}

	g, ok := s.getGroup(route.GroupID)
	if !ok {
		return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", route.GroupID)
	}

	// errors.WithStack(nil) returns nil, so the success path is unchanged.
	value, err := g.store.Get(ctx, key)
	return value, errors.WithStack(err)
}

// Compile-time assertion that ShardRouter satisfies Transactional.
var _ Transactional = (*ShardRouter)(nil)
255 changes: 255 additions & 0 deletions kv/shard_router_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package kv

import (
"context"
"fmt"
"testing"
"time"

"github.com/bootjp/elastickv/distribution"
pb "github.com/bootjp/elastickv/proto"
"github.com/bootjp/elastickv/store"
"github.com/hashicorp/raft"
)

// helper to create a multi-node raft cluster and return the leader
func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [golangci] reported by reviewdog 🐶
calculated cyclomatic complexity for function newTestRaft is 14, max is 10 (cyclop)

t.Helper()

const n = 3
addrs := make([]raft.ServerAddress, n)
trans := make([]*raft.InmemTransport, n)
for i := 0; i < n; i++ {
addr, tr := raft.NewInmemTransport(raft.ServerAddress(fmt.Sprintf("%s-%d", id, i)))
addrs[i] = addr
trans[i] = tr
}
// fully connect transports
for i := 0; i < n; i++ {
for j := i + 1; j < n; j++ {
trans[i].Connect(addrs[j], trans[j])
trans[j].Connect(addrs[i], trans[i])
}
}

// cluster configuration
cfg := raft.Configuration{}
for i := 0; i < n; i++ {
cfg.Servers = append(cfg.Servers, raft.Server{
ID: raft.ServerID(fmt.Sprintf("%s-%d", id, i)),
Address: addrs[i],
})
}

rafts := make([]*raft.Raft, n)
for i := 0; i < n; i++ {
c := raft.DefaultConfig()
c.LocalID = cfg.Servers[i].ID
if i == 0 {
c.HeartbeatTimeout = 200 * time.Millisecond
c.ElectionTimeout = 400 * time.Millisecond
c.LeaderLeaseTimeout = 100 * time.Millisecond
} else {
c.HeartbeatTimeout = 1 * time.Second
c.ElectionTimeout = 2 * time.Second
c.LeaderLeaseTimeout = 500 * time.Millisecond
}
ldb := raft.NewInmemStore()
sdb := raft.NewInmemStore()
fss := raft.NewInmemSnapshotStore()
var rfsm raft.FSM
if i == 0 {
rfsm = fsm
} else {
rfsm = NewKvFSM(store.NewRbMemoryStore(), store.NewRbMemoryStoreWithExpire(time.Minute))
}
r, err := raft.NewRaft(c, rfsm, ldb, sdb, fss, trans[i])
if err != nil {
t.Fatalf("new raft %d: %v", i, err)
}
if err := r.BootstrapCluster(cfg).Error(); err != nil {
t.Fatalf("bootstrap %d: %v", i, err)
}
rafts[i] = r
}

// node 0 should become leader
for i := 0; i < 100; i++ {
if rafts[0].State() == raft.Leader {
break
}
time.Sleep(50 * time.Millisecond)
}
if rafts[0].State() != raft.Leader {
t.Fatalf("node %s-0 is not leader", id)
}

shutdown := func() {
for _, r := range rafts {
r.Shutdown()
}
}
return rafts[0], shutdown
}

// TestShardRouterCommit verifies that a batch whose requests fall in two
// different key ranges is committed to both raft groups, and that Get
// routes reads to the shard that owns each key.
func TestShardRouterCommit(t *testing.T) {
	ctx := context.Background()

	e := distribution.NewEngine()
	e.UpdateRoute([]byte("a"), []byte("m"), 1)
	e.UpdateRoute([]byte("m"), nil, 2)

	router := NewShardRouter(e)

	// spin up one raft cluster per group and register it
	for _, gid := range []uint64{1, 2} {
		st := store.NewRbMemoryStore()
		lease := store.NewRbMemoryStoreWithExpire(time.Minute)
		r, stop := newTestRaft(t, fmt.Sprintf("%d", gid), NewKvFSM(st, lease))
		defer stop()
		router.Register(gid, NewTransaction(r), st)
	}

	reqs := []*pb.Request{
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
	}

	if _, err := router.Commit(reqs); err != nil {
		t.Fatalf("commit: %v", err)
	}

	v, err := router.Get(ctx, []byte("b"))
	if err != nil || string(v) != "v1" {
		t.Fatalf("group1 get: %v %v", v, err)
	}
	v, err = router.Get(ctx, []byte("x"))
	if err != nil || string(v) != "v2" {
		t.Fatalf("group2 get: %v %v", v, err)
	}
}

// TestShardRouterSplitAndMerge exercises rerouting: writes before a split
// go to group 1, writes after the split go to the new owner, and after the
// ranges are merged back everything routes to group 1 again.
func TestShardRouterSplitAndMerge(t *testing.T) {
	ctx := context.Background()

	e := distribution.NewEngine()
	// start with single shard handled by group 1
	e.UpdateRoute([]byte("a"), nil, 1)

	router := NewShardRouter(e)

	// group 1
	s1 := store.NewRbMemoryStore()
	r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, store.NewRbMemoryStoreWithExpire(time.Minute)))
	defer stop1()
	router.Register(1, NewTransaction(r1), s1)

	// group 2 (will be used after split)
	s2 := store.NewRbMemoryStore()
	r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, store.NewRbMemoryStoreWithExpire(time.Minute)))
	defer stop2()
	router.Register(2, NewTransaction(r2), s2)

	// put commits a single PUT through the router.
	put := func(key, value string) error {
		reqs := []*pb.Request{{
			IsTxn:     false,
			Phase:     pb.Phase_NONE,
			Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte(key), Value: []byte(value)}},
		}}
		_, err := router.Commit(reqs)
		return err
	}

	// initial write routed to group 1
	if err := put("b", "v1"); err != nil {
		t.Fatalf("commit group1: %v", err)
	}
	v, err := router.Get(ctx, []byte("b"))
	if err != nil || string(v) != "v1" {
		t.Fatalf("group1 value before split: %v %v", v, err)
	}

	// split shard: group1 handles [a,m), group2 handles [m,∞)
	split := distribution.NewEngine()
	split.UpdateRoute([]byte("a"), []byte("m"), 1)
	split.UpdateRoute([]byte("m"), nil, 2)
	router.engine = split

	// write routed to group 2 after split
	if err := put("x", "v2"); err != nil {
		t.Fatalf("commit group2: %v", err)
	}
	v, err = router.Get(ctx, []byte("x"))
	if err != nil || string(v) != "v2" {
		t.Fatalf("group2 value after split: %v %v", v, err)
	}

	// merge shards back: all keys handled by group1
	merged := distribution.NewEngine()
	merged.UpdateRoute([]byte("a"), nil, 1)
	router.engine = merged

	// write routed to group1 after merge
	if err := put("z", "v3"); err != nil {
		t.Fatalf("commit after merge: %v", err)
	}
	v, err = router.Get(ctx, []byte("z"))
	if err != nil || string(v) != "v3" {
		t.Fatalf("group1 value after merge: %v %v", v, err)
	}
}

// fakeTM is a test double for Transactional that counts calls and can be
// configured to fail Commit.
type fakeTM struct {
	// commitErr makes Commit return an error when true.
	commitErr bool
	// commitCalls counts Commit invocations.
	commitCalls int
	// abortCalls counts Abort invocations.
	abortCalls int
}

// Commit records the call and fails when the fake is configured to.
func (f *fakeTM) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
	f.commitCalls++
	if !f.commitErr {
		return &TransactionResponse{}, nil
	}
	return nil, fmt.Errorf("commit fail")
}

// Abort records the call and always succeeds.
func (f *fakeTM) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
	f.abortCalls++
	resp := &TransactionResponse{}
	return resp, nil
}

// TestShardRouterCommitFailure verifies that a failing group surfaces an
// error from Commit and that no aborts are issued.
//
// The original test asserted that BOTH groups received a commit, but
// process stops at the first failing group and Go map iteration order is
// unspecified, so the healthy group may never be reached — the assertion
// was flaky. Only the failing group's call count is guaranteed.
func TestShardRouterCommitFailure(t *testing.T) {
	e := distribution.NewEngine()
	e.UpdateRoute([]byte("a"), []byte("m"), 1)
	e.UpdateRoute([]byte("m"), nil, 2)

	router := NewShardRouter(e)

	ok := &fakeTM{}
	fail := &fakeTM{commitErr: true}
	router.Register(1, ok, nil)
	router.Register(2, fail, nil)

	reqs := []*pb.Request{
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
	}

	if _, err := router.Commit(reqs); err == nil {
		t.Fatalf("expected error")
	}

	// The error can only originate from the failing group's Commit, so it
	// must have been attempted at least once.
	if fail.commitCalls == 0 {
		t.Fatalf("expected commit attempt on failing group")
	}

	// Router does not roll back: no aborts on either group.
	if ok.abortCalls != 0 || fail.abortCalls != 0 {
		t.Fatalf("unexpected abort calls")
	}
}
Loading