Skip to content

Commit 6cfe5b0

Browse files
committed
Add MultiRaft sharding manager
1 parent 723b1ba commit 6cfe5b0

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

kv/sharded_transaction.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package kv
2+
3+
import (
4+
"sync"
5+
6+
"github.com/bootjp/elastickv/distribution"
7+
pb "github.com/bootjp/elastickv/proto"
8+
"github.com/cockroachdb/errors"
9+
)
10+
11+
// ShardedTransactionManager routes requests to multiple raft groups based on key ranges.
12+
type ShardedTransactionManager struct {
13+
engine *distribution.Engine
14+
mu sync.RWMutex
15+
groups map[uint64]Transactional
16+
}
17+
18+
// NewShardedTransactionManager creates a new manager.
19+
func NewShardedTransactionManager(e *distribution.Engine) *ShardedTransactionManager {
20+
return &ShardedTransactionManager{
21+
engine: e,
22+
groups: make(map[uint64]Transactional),
23+
}
24+
}
25+
26+
// Register associates a raft group ID with a Transactional.
27+
func (s *ShardedTransactionManager) Register(group uint64, tm Transactional) {
28+
s.mu.Lock()
29+
defer s.mu.Unlock()
30+
s.groups[group] = tm
31+
}
32+
33+
// Commit dispatches requests to the correct raft group.
34+
func (s *ShardedTransactionManager) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
35+
grouped, err := s.groupRequests(reqs)
36+
if err != nil {
37+
return nil, errors.WithStack(err)
38+
}
39+
40+
var max uint64
41+
for gid, rs := range grouped {
42+
tm, ok := s.getGroup(gid)
43+
if !ok {
44+
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
45+
}
46+
r, err := tm.Commit(rs)
47+
if err != nil {
48+
return nil, errors.WithStack(err)
49+
}
50+
if r.CommitIndex > max {
51+
max = r.CommitIndex
52+
}
53+
}
54+
return &TransactionResponse{CommitIndex: max}, nil
55+
}
56+
57+
// Abort dispatches aborts to the correct raft group.
58+
func (s *ShardedTransactionManager) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
59+
grouped, err := s.groupRequests(reqs)
60+
if err != nil {
61+
return nil, errors.WithStack(err)
62+
}
63+
64+
var max uint64
65+
for gid, rs := range grouped {
66+
tm, ok := s.getGroup(gid)
67+
if !ok {
68+
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
69+
}
70+
r, err := tm.Abort(rs)
71+
if err != nil {
72+
return nil, errors.WithStack(err)
73+
}
74+
if r.CommitIndex > max {
75+
max = r.CommitIndex
76+
}
77+
}
78+
return &TransactionResponse{CommitIndex: max}, nil
79+
}
80+
81+
func (s *ShardedTransactionManager) getGroup(id uint64) (Transactional, bool) {
82+
s.mu.RLock()
83+
defer s.mu.RUnlock()
84+
tm, ok := s.groups[id]
85+
return tm, ok
86+
}
87+
88+
func (s *ShardedTransactionManager) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
89+
batches := make(map[uint64][]*pb.Request)
90+
for _, r := range reqs {
91+
if len(r.Mutations) == 0 {
92+
return nil, ErrInvalidRequest
93+
}
94+
key := r.Mutations[0].Key
95+
route, ok := s.engine.GetRoute(key)
96+
if !ok {
97+
return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
98+
}
99+
batches[route.GroupID] = append(batches[route.GroupID], r)
100+
}
101+
return batches, nil
102+
}
103+
104+
var _ Transactional = (*ShardedTransactionManager)(nil)

kv/sharded_transaction_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package kv
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/bootjp/elastickv/distribution"
9+
pb "github.com/bootjp/elastickv/proto"
10+
"github.com/bootjp/elastickv/store"
11+
"github.com/hashicorp/raft"
12+
)
13+
14+
// helper to create single-node raft
15+
func newTestRaft(t *testing.T, id string, fsm raft.FSM) *raft.Raft {
16+
t.Helper()
17+
c := raft.DefaultConfig()
18+
c.LocalID = raft.ServerID(id)
19+
ldb := raft.NewInmemStore()
20+
sdb := raft.NewInmemStore()
21+
fss := raft.NewInmemSnapshotStore()
22+
addr, trans := raft.NewInmemTransport(raft.ServerAddress(id))
23+
r, err := raft.NewRaft(c, fsm, ldb, sdb, fss, trans)
24+
if err != nil {
25+
t.Fatalf("new raft: %v", err)
26+
}
27+
cfg := raft.Configuration{Servers: []raft.Server{{ID: raft.ServerID(id), Address: addr}}}
28+
if err := r.BootstrapCluster(cfg).Error(); err != nil {
29+
t.Fatalf("bootstrap: %v", err)
30+
}
31+
32+
// single node should eventually become leader
33+
for i := 0; i < 100; i++ {
34+
if r.State() == raft.Leader {
35+
break
36+
}
37+
time.Sleep(50 * time.Millisecond)
38+
}
39+
if r.State() != raft.Leader {
40+
t.Fatalf("node %s is not leader", id)
41+
}
42+
return r
43+
}
44+
45+
func TestShardedTransactionManagerCommit(t *testing.T) {
46+
e := distribution.NewEngine()
47+
e.UpdateRoute([]byte("a"), []byte("m"), 1)
48+
e.UpdateRoute([]byte("m"), nil, 2)
49+
50+
stm := NewShardedTransactionManager(e)
51+
52+
// group 1
53+
s1 := store.NewRbMemoryStore()
54+
l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
55+
r1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
56+
defer r1.Shutdown()
57+
stm.Register(1, NewTransaction(r1))
58+
59+
// group 2
60+
s2 := store.NewRbMemoryStore()
61+
l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
62+
r2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
63+
defer r2.Shutdown()
64+
stm.Register(2, NewTransaction(r2))
65+
66+
reqs := []*pb.Request{
67+
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
68+
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
69+
}
70+
71+
_, err := stm.Commit(reqs)
72+
if err != nil {
73+
t.Fatalf("commit: %v", err)
74+
}
75+
76+
v, err := s1.Get(context.Background(), []byte("b"))
77+
if err != nil || string(v) != "v1" {
78+
t.Fatalf("group1 value: %v %v", v, err)
79+
}
80+
v, err = s2.Get(context.Background(), []byte("x"))
81+
if err != nil || string(v) != "v2" {
82+
t.Fatalf("group2 value: %v %v", v, err)
83+
}
84+
}

0 commit comments

Comments
 (0)