Skip to content

Commit b5aadc8

Browse files
committed
refactor: rename sharded transaction manager
1 parent 7bedcdd commit b5aadc8

File tree

2 files changed

+86
-47
lines changed

2 files changed

+86
-47
lines changed
Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,54 +8,44 @@ import (
88
"github.com/cockroachdb/errors"
99
)
1010

11-
// ShardedTransactionManager routes requests to multiple raft groups based on key ranges.
12-
type ShardedTransactionManager struct {
11+
// ShardRouter routes requests to multiple raft groups based on key ranges.
12+
// It does not provide transactional guarantees across shards; commits are executed
13+
// per shard and failures may leave partial results.
14+
type ShardRouter struct {
1315
engine *distribution.Engine
1416
mu sync.RWMutex
1517
groups map[uint64]Transactional
1618
}
1719

18-
// NewShardedTransactionManager creates a new manager.
19-
func NewShardedTransactionManager(e *distribution.Engine) *ShardedTransactionManager {
20-
return &ShardedTransactionManager{
20+
// NewShardRouter creates a new router.
21+
func NewShardRouter(e *distribution.Engine) *ShardRouter {
22+
return &ShardRouter{
2123
engine: e,
2224
groups: make(map[uint64]Transactional),
2325
}
2426
}
2527

2628
// Register associates a raft group ID with a Transactional.
27-
func (s *ShardedTransactionManager) Register(group uint64, tm Transactional) {
29+
func (s *ShardRouter) Register(group uint64, tm Transactional) {
2830
s.mu.Lock()
2931
defer s.mu.Unlock()
3032
s.groups[group] = tm
3133
}
3234

33-
// Commit dispatches requests to the correct raft group.
34-
func (s *ShardedTransactionManager) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
35-
grouped, err := s.groupRequests(reqs)
36-
if err != nil {
37-
return nil, errors.WithStack(err)
38-
}
39-
40-
var max uint64
41-
for gid, rs := range grouped {
42-
tm, ok := s.getGroup(gid)
43-
if !ok {
44-
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
45-
}
46-
r, err := tm.Commit(rs)
47-
if err != nil {
48-
return nil, errors.WithStack(err)
49-
}
50-
if r.CommitIndex > max {
51-
max = r.CommitIndex
52-
}
53-
}
54-
return &TransactionResponse{CommitIndex: max}, nil
35+
func (s *ShardRouter) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
36+
return s.process(reqs, func(tm Transactional, rs []*pb.Request) (*TransactionResponse, error) {
37+
return tm.Commit(rs)
38+
})
5539
}
5640

5741
// Abort dispatches aborts to the correct raft group.
58-
func (s *ShardedTransactionManager) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
42+
func (s *ShardRouter) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
43+
return s.process(reqs, func(tm Transactional, rs []*pb.Request) (*TransactionResponse, error) {
44+
return tm.Abort(rs)
45+
})
46+
}
47+
48+
func (s *ShardRouter) process(reqs []*pb.Request, fn func(Transactional, []*pb.Request) (*TransactionResponse, error)) (*TransactionResponse, error) {
5949
grouped, err := s.groupRequests(reqs)
6050
if err != nil {
6151
return nil, errors.WithStack(err)
@@ -67,7 +57,7 @@ func (s *ShardedTransactionManager) Abort(reqs []*pb.Request) (*TransactionRespo
6757
if !ok {
6858
return nil, errors.Wrapf(ErrInvalidRequest, "unknown group %d", gid)
6959
}
70-
r, err := tm.Abort(rs)
60+
r, err := fn(tm, rs)
7161
if err != nil {
7262
return nil, errors.WithStack(err)
7363
}
@@ -78,14 +68,14 @@ func (s *ShardedTransactionManager) Abort(reqs []*pb.Request) (*TransactionRespo
7868
return &TransactionResponse{CommitIndex: max}, nil
7969
}
8070

81-
func (s *ShardedTransactionManager) getGroup(id uint64) (Transactional, bool) {
71+
func (s *ShardRouter) getGroup(id uint64) (Transactional, bool) {
8272
s.mu.RLock()
8373
defer s.mu.RUnlock()
8474
tm, ok := s.groups[id]
8575
return tm, ok
8676
}
8777

88-
func (s *ShardedTransactionManager) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
78+
func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Request, error) {
8979
batches := make(map[uint64][]*pb.Request)
9080
for _, r := range reqs {
9181
if len(r.Mutations) == 0 {
@@ -101,4 +91,4 @@ func (s *ShardedTransactionManager) groupRequests(reqs []*pb.Request) (map[uint6
10191
return batches, nil
10292
}
10393

104-
var _ Transactional = (*ShardedTransactionManager)(nil)
94+
var _ Transactional = (*ShardRouter)(nil)
Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,33 +92,33 @@ func newTestRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) {
9292
return rafts[0], shutdown
9393
}
9494

95-
func TestShardedTransactionManagerCommit(t *testing.T) {
95+
func TestShardRouterCommit(t *testing.T) {
9696
e := distribution.NewEngine()
9797
e.UpdateRoute([]byte("a"), []byte("m"), 1)
9898
e.UpdateRoute([]byte("m"), nil, 2)
9999

100-
stm := NewShardedTransactionManager(e)
100+
router := NewShardRouter(e)
101101

102102
// group 1
103103
s1 := store.NewRbMemoryStore()
104104
l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
105105
r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
106106
defer stop1()
107-
stm.Register(1, NewTransaction(r1))
107+
router.Register(1, NewTransaction(r1))
108108

109109
// group 2
110110
s2 := store.NewRbMemoryStore()
111111
l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
112112
r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
113113
defer stop2()
114-
stm.Register(2, NewTransaction(r2))
114+
router.Register(2, NewTransaction(r2))
115115

116116
reqs := []*pb.Request{
117117
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
118118
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
119119
}
120120

121-
_, err := stm.Commit(reqs)
121+
_, err := router.Commit(reqs)
122122
if err != nil {
123123
t.Fatalf("commit: %v", err)
124124
}
@@ -133,34 +133,34 @@ func TestShardedTransactionManagerCommit(t *testing.T) {
133133
}
134134
}
135135

136-
func TestShardedTransactionManagerSplitAndMerge(t *testing.T) {
136+
func TestShardRouterSplitAndMerge(t *testing.T) {
137137
ctx := context.Background()
138138

139139
e := distribution.NewEngine()
140140
// start with single shard handled by group 1
141141
e.UpdateRoute([]byte("a"), nil, 1)
142142

143-
stm := NewShardedTransactionManager(e)
143+
router := NewShardRouter(e)
144144

145145
// group 1
146146
s1 := store.NewRbMemoryStore()
147147
l1 := store.NewRbMemoryStoreWithExpire(time.Minute)
148148
r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1))
149149
defer stop1()
150-
stm.Register(1, NewTransaction(r1))
150+
router.Register(1, NewTransaction(r1))
151151

152152
// group 2 (will be used after split)
153153
s2 := store.NewRbMemoryStore()
154154
l2 := store.NewRbMemoryStoreWithExpire(time.Minute)
155155
r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2))
156156
defer stop2()
157-
stm.Register(2, NewTransaction(r2))
157+
router.Register(2, NewTransaction(r2))
158158

159159
// initial write routed to group 1
160160
req := []*pb.Request{
161161
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
162162
}
163-
if _, err := stm.Commit(req); err != nil {
163+
if _, err := router.Commit(req); err != nil {
164164
t.Fatalf("commit group1: %v", err)
165165
}
166166
v, err := s1.Get(ctx, []byte("b"))
@@ -172,13 +172,13 @@ func TestShardedTransactionManagerSplitAndMerge(t *testing.T) {
172172
e2 := distribution.NewEngine()
173173
e2.UpdateRoute([]byte("a"), []byte("m"), 1)
174174
e2.UpdateRoute([]byte("m"), nil, 2)
175-
stm.engine = e2
175+
router.engine = e2
176176

177177
// write routed to group 2 after split
178178
req = []*pb.Request{
179179
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
180180
}
181-
if _, err := stm.Commit(req); err != nil {
181+
if _, err := router.Commit(req); err != nil {
182182
t.Fatalf("commit group2: %v", err)
183183
}
184184
v, err = s2.Get(ctx, []byte("x"))
@@ -189,17 +189,66 @@ func TestShardedTransactionManagerSplitAndMerge(t *testing.T) {
189189
// merge shards back: all keys handled by group1
190190
e3 := distribution.NewEngine()
191191
e3.UpdateRoute([]byte("a"), nil, 1)
192-
stm.engine = e3
192+
router.engine = e3
193193

194194
// write routed to group1 after merge
195195
req = []*pb.Request{
196196
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("z"), Value: []byte("v3")}}},
197197
}
198-
if _, err := stm.Commit(req); err != nil {
198+
if _, err := router.Commit(req); err != nil {
199199
t.Fatalf("commit after merge: %v", err)
200200
}
201201
v, err = s1.Get(ctx, []byte("z"))
202202
if err != nil || string(v) != "v3" {
203203
t.Fatalf("group1 value after merge: %v %v", v, err)
204204
}
205205
}
206+
207+
type fakeTM struct {
208+
commitErr bool
209+
commitCalls int
210+
abortCalls int
211+
}
212+
213+
func (f *fakeTM) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
214+
f.commitCalls++
215+
if f.commitErr {
216+
return nil, fmt.Errorf("commit fail")
217+
}
218+
return &TransactionResponse{}, nil
219+
}
220+
221+
func (f *fakeTM) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
222+
f.abortCalls++
223+
return &TransactionResponse{}, nil
224+
}
225+
226+
func TestShardRouterCommitFailure(t *testing.T) {
227+
e := distribution.NewEngine()
228+
e.UpdateRoute([]byte("a"), []byte("m"), 1)
229+
e.UpdateRoute([]byte("m"), nil, 2)
230+
231+
router := NewShardRouter(e)
232+
233+
ok := &fakeTM{}
234+
fail := &fakeTM{commitErr: true}
235+
router.Register(1, ok)
236+
router.Register(2, fail)
237+
238+
reqs := []*pb.Request{
239+
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}},
240+
{IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}},
241+
}
242+
243+
if _, err := router.Commit(reqs); err == nil {
244+
t.Fatalf("expected error")
245+
}
246+
247+
if fail.commitCalls == 0 || ok.commitCalls == 0 {
248+
t.Fatalf("expected commits on both groups")
249+
}
250+
251+
if ok.abortCalls != 0 {
252+
t.Fatalf("unexpected abort on successful group")
253+
}
254+
}

0 commit comments

Comments
 (0)