diff --git a/kv/shard_router.go b/kv/shard_router.go
new file mode 100644
index 0000000..556860d
--- /dev/null
+++ b/kv/shard_router.go
@@ -0,0 +1,123 @@
+package kv
+
+import (
+	"context"
+	"sync"
+
+	"github.com/bootjp/elastickv/distribution"
+	pb "github.com/bootjp/elastickv/proto"
+	"github.com/bootjp/elastickv/store"
+	"github.com/cockroachdb/errors"
+)
+
+// ShardRouter routes requests to multiple raft groups based on key ranges.
+// It does not provide transactional guarantees across shards; commits are executed
+// per shard and failures may leave partial results.
+type ShardRouter struct {
+	engine *distribution.Engine
+	mu     sync.RWMutex
+	groups map[uint64]*routerGroup
+}
+
+type routerGroup struct {
+	tm    Transactional
+	store store.Store
+}
+
+// NewShardRouter creates a router backed by the given routing engine.
+func NewShardRouter(e *distribution.Engine) *ShardRouter {
+	return &ShardRouter{
+		engine: e,
+		groups: make(map[uint64]*routerGroup),
+	}
+}
+
+// Register associates a raft group ID with its transactional manager and store.
+func (s *ShardRouter) Register(group uint64, tm Transactional, st store.Store) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.groups[group] = &routerGroup{tm: tm, store: st}
+}
+
+// Commit groups requests by key range and commits each batch on its raft group.
+func (s *ShardRouter) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
+	return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) {
+		return g.tm.Commit(rs)
+	})
+}
+
+// Abort dispatches aborts to the correct raft group.
+func (s *ShardRouter) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
+	return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) {
+		return g.tm.Abort(rs)
+	})
+}
+
+// process fans grouped requests out to their raft groups and returns the
+// highest commit index seen; it stops at the first error without rollback.
+func (s *ShardRouter) process(reqs []*pb.Request, fn func(*routerGroup, []*pb.Request) (*TransactionResponse, error)) (*TransactionResponse, error) {
+	grouped, err := s.groupRequests(reqs)
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	var max uint64
+	for gid, rs := range grouped {
+		g, ok := s.getGroup(gid)
+		if !ok {
+			return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
+		}
+		r, err := fn(g, rs)
+		if err != nil {
+			return nil, errors.WithStack(err)
+		}
+		if r.CommitIndex > max {
+			max = r.CommitIndex
+		}
+	}
+	return &TransactionResponse{CommitIndex: max}, nil
+}
+
+func (s *ShardRouter) getGroup(id uint64) (*routerGroup, bool) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	g, ok := s.groups[id]
+	return g, ok
+}
+
+// groupRequests batches requests by owning raft group, routing each request
+// on its first mutation's key.
+func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
+	batches := make(map[uint64][]*pb.Request)
+	for _, r := range reqs {
+		if len(r.Mutations) == 0 {
+			return nil, ErrInvalidRequest
+		}
+		key := r.Mutations[0].Key
+		route, ok := s.engine.GetRoute(key)
+		if !ok {
+			return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
+		}
+		batches[route.GroupID] = append(batches[route.GroupID], r)
+	}
+	return batches, nil
+}
+
+// Get retrieves a key routed to the correct shard.
+func (s *ShardRouter) Get(ctx context.Context, key []byte) ([]byte, error) {
+	route, ok := s.engine.GetRoute(key)
+	if !ok {
+		return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
+	}
+	g, ok := s.getGroup(route.GroupID)
+	if !ok {
+		return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", route.GroupID)
+	}
+	v, err := g.store.Get(ctx, key)
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+	return v, nil
+}
+
+var _ Transactional = (*ShardRouter)(nil)
diff --git a/kv/shard_router_test.go b/kv/shard_router_test.go
new file mode 100644
index 0000000..3f87850
--- /dev/null
+++ b/kv/shard_router_test.go
@@ -0,0 +1,257 @@
+package kv
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/bootjp/elastickv/distribution"
+	pb "github.com/bootjp/elastickv/proto"
+	"github.com/bootjp/elastickv/store"
+	"github.com/hashicorp/raft"
+)
+
+// helper to create a multi-node raft cluster and return the leader
+func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
+	t.Helper()
+
+	const n = 3
+	addrs := make([]raft.ServerAddress, n)
+	trans := make([]*raft.InmemTransport, n)
+	for i := 0; i < n; i++ {
+		addr, tr := raft.NewInmemTransport(raft.ServerAddress(fmt.Sprintf("%s-%d", id, i)))
+		addrs[i] = addr
+		trans[i] = tr
+	}
+	// fully connect transports
+	for i := 0; i < n; i++ {
+		for j := i + 1; j < n; j++ {
+			trans[i].Connect(addrs[j], trans[j])
+			trans[j].Connect(addrs[i], trans[i])
+		}
+	}
+
+	// cluster configuration
+	cfg := raft.Configuration{}
+	for i := 0; i < n; i++ {
+		cfg.Servers = append(cfg.Servers, raft.Server{
+			ID:      raft.ServerID(fmt.Sprintf("%s-%d", id, i)),
+			Address: addrs[i],
+		})
+	}
+
+	rafts := make([]*raft.Raft, n)
+	for i := 0; i < n; i++ {
+		c := raft.DefaultConfig()
+		c.LocalID = cfg.Servers[i].ID
+		if i == 0 {
+			c.HeartbeatTimeout = 200 * time.Millisecond
+			c.ElectionTimeout = 400 * time.Millisecond
+			c.LeaderLeaseTimeout = 100 * time.Millisecond
+		} else {
+			c.HeartbeatTimeout = 1 * time.Second
+			c.ElectionTimeout = 2 * time.Second
+			c.LeaderLeaseTimeout = 500 * time.Millisecond
+		}
+		ldb := raft.NewInmemStore()
+		sdb := raft.NewInmemStore()
+		fss := raft.NewInmemSnapshotStore()
+		var rfsm raft.FSM
+		if i == 0 {
+			rfsm = fsm
+		} else {
+			rfsm = NewKvFSM(store.NewRbMemoryStore(), store.NewRbMemoryStoreWithExpire(time.Minute))
+		}
+		r, err := raft.NewRaft(c, rfsm, ldb, sdb, fss, trans[i])
+		if err != nil {
+			t.Fatalf("new raft %d: %v", i, err)
+		}
+		if err := r.BootstrapCluster(cfg).Error(); err != nil {
+			t.Fatalf("bootstrap %d: %v", i, err)
+		}
+		rafts[i] = r
+	}
+
+	// node 0 should become leader
+	for i := 0; i < 100; i++ {
+		if rafts[0].State() == raft.Leader {
+			break
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+	if rafts[0].State() != raft.Leader {
+		t.Fatalf("node %s-0 is not leader", id)
+	}
+
+	shutdown := func() {
+		for _, r := range rafts {
+			r.Shutdown()
+		}
+	}
+	return rafts[0], shutdown
+}
+
+func TestShardRouterCommit(t *testing.T) {
+	ctx := context.Background()
+
+	e := distribution.NewEngine()
+	e.UpdateRoute([]byte("a"), []byte("m"), 1)
+	e.UpdateRoute([]byte("m"), nil, 2)
+
+	router := NewShardRouter(e)
+
+	// group 1
+	s1 := store.NewRbMemoryStore()
+	l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
+	r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
+	defer stop1()
+	router.Register(1, NewTransaction(r1), s1)
+
+	// group 2
+	s2 := store.NewRbMemoryStore()
+	l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
+	r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
+	defer stop2()
+	router.Register(2, NewTransaction(r2), s2)
+
+	reqs := []*pb.Request{
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
+	}
+
+	if _, err := router.Commit(reqs); err != nil {
+		t.Fatalf("commit: %v", err)
+	}
+
+	v, err := router.Get(ctx, []byte("b"))
+	if err != nil || string(v) != "v1" {
+		t.Fatalf("group1 get: %q %v", v, err)
+	}
+	v, err = router.Get(ctx, []byte("x"))
+	if err != nil || string(v) != "v2" {
+		t.Fatalf("group2 get: %q %v", v, err)
+	}
+}
+
+func TestShardRouterSplitAndMerge(t *testing.T) {
+	ctx := context.Background()
+
+	e := distribution.NewEngine()
+	// start with a single shard handled by group 1
+	e.UpdateRoute([]byte("a"), nil, 1)
+
+	router := NewShardRouter(e)
+
+	// group 1
+	s1 := store.NewRbMemoryStore()
+	l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
+	r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
+	defer stop1()
+	router.Register(1, NewTransaction(r1), s1)
+
+	// group 2 (will be used after split)
+	s2 := store.NewRbMemoryStore()
+	l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
+	r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
+	defer stop2()
+	router.Register(2, NewTransaction(r2), s2)
+
+	// initial write routed to group 1
+	req := []*pb.Request{
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
+	}
+	if _, err := router.Commit(req); err != nil {
+		t.Fatalf("commit group1: %v", err)
+	}
+	v, err := router.Get(ctx, []byte("b"))
+	if err != nil || string(v) != "v1" {
+		t.Fatalf("group1 value before split: %q %v", v, err)
+	}
+
+	// split shard: group1 handles [a,m), group2 handles [m,∞)
+	e2 := distribution.NewEngine()
+	e2.UpdateRoute([]byte("a"), []byte("m"), 1)
+	e2.UpdateRoute([]byte("m"), nil, 2)
+	router.engine = e2
+
+	// write routed to group 2 after split
+	req = []*pb.Request{
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
+	}
+	if _, err := router.Commit(req); err != nil {
+		t.Fatalf("commit group2: %v", err)
+	}
+	v, err = router.Get(ctx, []byte("x"))
+	if err != nil || string(v) != "v2" {
+		t.Fatalf("group2 value after split: %q %v", v, err)
+	}
+
+	// merge shards back: all keys handled by group1
+	e3 := distribution.NewEngine()
+	e3.UpdateRoute([]byte("a"), nil, 1)
+	router.engine = e3
+
+	// write routed to group1 after merge
+	req = []*pb.Request{
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("z"), Value: []byte("v3")}}},
+	}
+	if _, err := router.Commit(req); err != nil {
+		t.Fatalf("commit after merge: %v", err)
+	}
+	v, err = router.Get(ctx, []byte("z"))
+	if err != nil || string(v) != "v3" {
+		t.Fatalf("group1 value after merge: %q %v", v, err)
+	}
+}
+
+type fakeTM struct {
+	commitErr   bool
+	commitCalls int
+	abortCalls  int
+}
+
+func (f *fakeTM) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
+	f.commitCalls++
+	if f.commitErr {
+		return nil, fmt.Errorf("commit fail")
+	}
+	return &TransactionResponse{}, nil
+}
+
+func (f *fakeTM) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
+	f.abortCalls++
+	return &TransactionResponse{}, nil
+}
+
+func TestShardRouterCommitFailure(t *testing.T) {
+	e := distribution.NewEngine()
+	e.UpdateRoute([]byte("a"), []byte("m"), 1)
+	e.UpdateRoute([]byte("m"), nil, 2)
+
+	router := NewShardRouter(e)
+
+	ok := &fakeTM{}
+	fail := &fakeTM{commitErr: true}
+	router.Register(1, ok, nil)
+	router.Register(2, fail, nil)
+
+	reqs := []*pb.Request{
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
+		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
+	}
+
+	if _, err := router.Commit(reqs); err == nil {
+		t.Fatalf("expected error")
+	}
+
+	// Map iteration order is random, so the healthy group may or may not have
+	// been reached before the failure; assert only the failing group.
+	if fail.commitCalls != 1 {
+		t.Fatalf("expected one commit attempt on failing group, got %d", fail.commitCalls)
+	}
+
+	if ok.abortCalls != 0 {
+		t.Fatalf("unexpected abort on successful group")
+	}
+}
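
Usage note: the sketch below condenses the wiring exercised by `TestShardRouterCommit` into one place. The wrapper name `wireAndWrite` is illustrative only; every call it makes (`distribution.NewEngine`, `UpdateRoute`, `NewShardRouter`, `Register`, `NewTransaction`, `Commit`, `Get`) is introduced in this diff, and the snippet assumes it sits in package `kv` with two raft leaders `r1`/`r2` and their stores `s1`/`s2` built elsewhere (e.g. via `newTestRaft`).

```go
// wireAndWrite is an illustrative helper, not part of this change: r1/r2 are
// leaders of two raft groups and s1/s2 their backing stores.
func wireAndWrite(ctx context.Context, r1, r2 *raft.Raft, s1, s2 store.Store) error {
	e := distribution.NewEngine()
	e.UpdateRoute([]byte("a"), []byte("m"), 1) // group 1 owns [a, m)
	e.UpdateRoute([]byte("m"), nil, 2)         // group 2 owns [m, ∞)

	router := NewShardRouter(e)
	router.Register(1, NewTransaction(r1), s1)
	router.Register(2, NewTransaction(r2), s2)

	// One batch spanning both ranges: groupRequests splits it by the first
	// mutation's key and each sub-batch commits on its own raft group.
	_, err := router.Commit([]*pb.Request{
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}}, // -> group 1
		{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}}, // -> group 2
	})
	if err != nil {
		return err // a failing group does not roll back groups already committed
	}

	// Reads take the same route table but hit the owning shard's store.
	_, err = router.Get(ctx, []byte("x"))
	return err
}
```

Because the fan-out has no cross-shard coordination, callers that need atomicity must keep all keys of a batch inside a single range.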