Skip to content

Commit 8d82b6d

Browse files
committed
route gets through shard router
1 parent b5aadc8 commit 8d82b6d

File tree

2 files changed

+55
-30
lines changed

2 files changed

+55
-30
lines changed

kv/shard_router.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package kv
22

33
import (
4+
"context"
45
"sync"
56

67
"github.com/bootjp/elastickv/distribution"
78
pb "github.com/bootjp/elastickv/proto"
9+
"github.com/bootjp/elastickv/store"
810
"github.com/cockroachdb/errors"
911
)
1012

@@ -14,50 +16,55 @@ import (
1416
type ShardRouter struct {
1517
engine *distribution.Engine
1618
mu sync.RWMutex
17-
groups map[uint64]Transactional
19+
groups map[uint64]*routerGroup
20+
}
21+
22+
type routerGroup struct {
23+
tm Transactional
24+
store store.Store
1825
}
1926

2027
// NewShardRouter creates a new router.
2128
func NewShardRouter(e *distribution.Engine) *ShardRouter {
2229
return &ShardRouter{
2330
engine: e,
24-
groups: make(map[uint64]Transactional),
31+
groups: make(map[uint64]*routerGroup),
2532
}
2633
}
2734

28-
// Register associates a raft group ID with a Transactional.
29-
func (s *ShardRouter) Register(group uint64, tm Transactional) {
35+
// Register associates a raft group ID with its transactional manager and store.
36+
func (s *ShardRouter) Register(group uint64, tm Transactional, st store.Store) {
3037
s.mu.Lock()
3138
defer s.mu.Unlock()
32-
s.groups[group] = tm
39+
s.groups[group] = &routerGroup{tm: tm, store: st}
3340
}
3441

3542
func (s *ShardRouter) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
36-
return s.process(reqs, func(tm Transactional, rs []*pb.Request) (*TransactionResponse, error) {
37-
return tm.Commit(rs)
43+
return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) {
44+
return g.tm.Commit(rs)
3845
})
3946
}
4047

4148
// Abort dispatches aborts to the correct raft group.
4249
func (s *ShardRouter) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
43-
return s.process(reqs, func(tm Transactional, rs []*pb.Request) (*TransactionResponse, error) {
44-
return tm.Abort(rs)
50+
return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) {
51+
return g.tm.Abort(rs)
4552
})
4653
}
4754

48-
func (s *ShardRouter) process(reqs []*pb.Request, fn func(Transactional, []*pb.Request) (*TransactionResponse, error)) (*TransactionResponse, error) {
55+
func (s *ShardRouter) process(reqs []*pb.Request, fn func(*routerGroup, []*pb.Request) (*TransactionResponse, error)) (*TransactionResponse, error) {
4956
grouped, err := s.groupRequests(reqs)
5057
if err != nil {
5158
return nil, errors.WithStack(err)
5259
}
5360

5461
var max uint64
5562
for gid, rs := range grouped {
56-
tm, ok := s.getGroup(gid)
63+
g, ok := s.getGroup(gid)
5764
if !ok {
5865
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
5966
}
60-
r, err := fn(tm, rs)
67+
r, err := fn(g, rs)
6168
if err != nil {
6269
return nil, errors.WithStack(err)
6370
}
@@ -68,11 +75,11 @@ func (s *ShardRouter) process(reqs []*pb.Request, fn func(Transactional, []*pb.R
6875
return &TransactionResponse{CommitIndex: max}, nil
6976
}
7077

71-
func (s *ShardRouter) getGroup(id uint64) (Transactional, bool) {
78+
func (s *ShardRouter) getGroup(id uint64) (*routerGroup, bool) {
7279
s.mu.RLock()
7380
defer s.mu.RUnlock()
74-
tm, ok := s.groups[id]
75-
return tm, ok
81+
g, ok := s.groups[id]
82+
return g, ok
7683
}
7784

7885
func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
@@ -91,4 +98,21 @@ func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Reques
9198
return batches, nil
9299
}
93100

101+
// Get retrieves a key routed to the correct shard.
102+
func (s *ShardRouter) Get(ctx context.Context, key []byte) ([]byte, error) {
103+
route, ok := s.engine.GetRoute(key)
104+
if !ok {
105+
return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key)
106+
}
107+
g, ok := s.getGroup(route.GroupID)
108+
if !ok {
109+
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", route.GroupID)
110+
}
111+
v, err := g.store.Get(ctx, key)
112+
if err != nil {
113+
return nil, errors.WithStack(err)
114+
}
115+
return v, nil
116+
}
117+
94118
var _ Transactional = (*ShardRouter)(nil)

kv/shard_router_test.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
9393
}
9494

9595
func TestShardRouterCommit(t *testing.T) {
96+
ctx := context.Background()
97+
9698
e := distribution.NewEngine()
9799
e.UpdateRoute([]byte("a"), []byte("m"), 1)
98100
e.UpdateRoute([]byte("m"), nil, 2)
@@ -104,32 +106,31 @@ func TestShardRouterCommit(t *testing.T) {
104106
l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
105107
r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
106108
defer stop1()
107-
router.Register(1, NewTransaction(r1))
109+
router.Register(1, NewTransaction(r1), s1)
108110

109111
// group 2
110112
s2 := store.NewRbMemoryStore()
111113
l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
112114
r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
113115
defer stop2()
114-
router.Register(2, NewTransaction(r2))
116+
router.Register(2, NewTransaction(r2), s2)
115117

116118
reqs := []*pb.Request{
117119
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
118120
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
119121
}
120122

121-
_, err := router.Commit(reqs)
122-
if err != nil {
123+
if _, err := router.Commit(reqs); err != nil {
123124
t.Fatalf("commit: %v", err)
124125
}
125126

126-
v, err := s1.Get(context.Background(), []byte("b"))
127+
v, err := router.Get(ctx, []byte("b"))
127128
if err != nil || string(v) != "v1" {
128-
t.Fatalf("group1 value: %v %v", v, err)
129+
t.Fatalf("group1 get: %v %v", v, err)
129130
}
130-
v, err = s2.Get(context.Background(), []byte("x"))
131+
v, err = router.Get(ctx, []byte("x"))
131132
if err != nil || string(v) != "v2" {
132-
t.Fatalf("group2 value: %v %v", v, err)
133+
t.Fatalf("group2 get: %v %v", v, err)
133134
}
134135
}
135136

@@ -147,14 +148,14 @@ func TestShardRouterSplitAndMerge(t *testing.T) {
147148
l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
148149
r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
149150
defer stop1()
150-
router.Register(1, NewTransaction(r1))
151+
router.Register(1, NewTransaction(r1), s1)
151152

152153
// group 2 (will be used after split)
153154
s2 := store.NewRbMemoryStore()
154155
l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
155156
r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
156157
defer stop2()
157-
router.Register(2, NewTransaction(r2))
158+
router.Register(2, NewTransaction(r2), s2)
158159

159160
// initial write routed to group 1
160161
req := []*pb.Request{
@@ -163,7 +164,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) {
163164
if _, err := router.Commit(req); err != nil {
164165
t.Fatalf("commit group1: %v", err)
165166
}
166-
v, err := s1.Get(ctx, []byte("b"))
167+
v, err := router.Get(ctx, []byte("b"))
167168
if err != nil || string(v) != "v1" {
168169
t.Fatalf("group1 value before split: %v %v", v, err)
169170
}
@@ -181,7 +182,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) {
181182
if _, err := router.Commit(req); err != nil {
182183
t.Fatalf("commit group2: %v", err)
183184
}
184-
v, err = s2.Get(ctx, []byte("x"))
185+
v, err = router.Get(ctx, []byte("x"))
185186
if err != nil || string(v) != "v2" {
186187
t.Fatalf("group2 value after split: %v %v", v, err)
187188
}
@@ -198,7 +199,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) {
198199
if _, err := router.Commit(req); err != nil {
199200
t.Fatalf("commit after merge: %v", err)
200201
}
201-
v, err = s1.Get(ctx, []byte("z"))
202+
v, err = router.Get(ctx, []byte("z"))
202203
if err != nil || string(v) != "v3" {
203204
t.Fatalf("group1 value after merge: %v %v", v, err)
204205
}
@@ -232,8 +233,8 @@ func TestShardRouterCommitFailure(t *testing.T) {
232233

233234
ok := &fakeTM{}
234235
fail := &fakeTM{commitErr: true}
235-
router.Register(1, ok)
236-
router.Register(2, fail)
236+
router.Register(1, ok, nil)
237+
router.Register(2, fail, nil)
237238

238239
reqs := []*pb.Request{
239240
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},

0 commit comments

Comments
 (0)