Skip to content

Commit 80a021f

Browse files
committed
retrieve values during begin
1 parent 988d4dc commit 80a021f

File tree

2 files changed

+157
-26
lines changed

2 files changed

+157
-26
lines changed

atlas/socket/handler.go

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,6 @@ ready:
264264
return
265265
}
266266

267-
var changes []*consensus.Migration
268-
269267
for {
270268
select {
271269
case <-ctx.Done():
@@ -298,30 +296,7 @@ ready:
298296
_, err = s.PerformExecute(ctx, cmd)
299297
goto handleError
300298
case "QUERY":
301-
var q *Query
302-
q, err = s.PerformQuery(ctx, cmd)
303-
if err != nil {
304-
goto handleError
305-
}
306-
if len(q.tables) > 0 {
307-
// we need ownership over all tables
308-
for _, table := range q.tables {
309-
qm := consensus.GetDefaultQuorumManager(ctx)
310-
var qu consensus.Quorum
311-
qu, err = qm.GetQuorum(ctx, table.GetName())
312-
if err != nil {
313-
err = makeFatal(err)
314-
goto handleError
315-
}
316-
317-
result, err = qu.StealTableOwnership(ctx, &consensus.StealTableOwnershipRequest{
318-
Sender: consensus.ConstructCurrentNode(),
319-
Reason: consensus.StealReason_schemaReason,
320-
Table: nil,
321-
})
322-
}
323-
}
324-
299+
_, err = s.PerformQuery(ctx, cmd)
325300
goto handleError
326301
case "FINALIZE":
327302
err = s.PerformFinalize(cmd)
@@ -505,6 +480,103 @@ ready:
505480
}
506481
}
507482

483+
func (s *Socket) SanitizeBegin(cmd commands.Command) (tables, views, triggers []string, err error) {
484+
if s.inTransaction {
485+
err = errors.New("the transaction is already in progress")
486+
return
487+
}
488+
489+
extractList := func() (list []string, err error) {
490+
var rip []string
491+
n := 3
492+
expectingFirst := true
493+
expectingLast := true
494+
for {
495+
first, _ := cmd.SelectNormalizedCommand(n)
496+
rip = append(rip, first)
497+
n++
498+
isLast := false
499+
first = strings.TrimSuffix(first, ",")
500+
if first == "" {
501+
break
502+
}
503+
if first == "(" {
504+
expectingFirst = false
505+
continue
506+
}
507+
if first == ")" {
508+
expectingLast = false
509+
break
510+
}
511+
if strings.HasPrefix(first, "(") {
512+
first = strings.TrimPrefix(first, "(")
513+
expectingFirst = false
514+
}
515+
if strings.HasSuffix(first, ")") {
516+
expectingLast = false
517+
first = strings.TrimSuffix(first, ")")
518+
isLast = true
519+
}
520+
if expectingFirst {
521+
err = errors.New("expected table name in parentheses")
522+
return
523+
}
524+
525+
if first == "," {
526+
continue
527+
}
528+
529+
list = append(list, cmd.NormalizeName(first))
530+
if isLast {
531+
break
532+
}
533+
}
534+
if expectingFirst || expectingLast {
535+
err = errors.New("expected table name in parentheses")
536+
return
537+
}
538+
cmd = cmd.ReplaceCommand(strings.Join(rip, " "), "BEGIN IMMEDIATE")
539+
return
540+
}
541+
542+
if t, ok := cmd.SelectNormalizedCommand(1); ok && t == "IMMEDIATE" {
543+
if err = cmd.CheckMinLen(3); err != nil {
544+
return
545+
}
546+
for {
547+
switch t, _ = cmd.SelectNormalizedCommand(2); t {
548+
case "TABLE":
549+
if err = cmd.CheckMinLen(4); err != nil {
550+
return
551+
}
552+
tables, err = extractList()
553+
if err != nil {
554+
return
555+
}
556+
case "VIEW":
557+
if err = cmd.CheckMinLen(4); err != nil {
558+
return
559+
}
560+
views, err = extractList()
561+
if err != nil {
562+
return
563+
}
564+
case "TRIGGER":
565+
if err = cmd.CheckMinLen(4); err != nil {
566+
return
567+
}
568+
triggers, err = extractList()
569+
if err != nil {
570+
return
571+
}
572+
default:
573+
return
574+
}
575+
}
576+
}
577+
return
578+
}
579+
508580
func (s *Socket) PerformFinalize(cmd *commands.CommandString) (err error) {
509581
if err = cmd.CheckExactLen(2); err != nil {
510582
return

atlas/socket/prepare_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,62 @@ func TestPrepare_Handle(t *testing.T) {
7676
})
7777
}
7878
}
79+
80+
func TestSanitizeBegin(t *testing.T) {
81+
tests := []struct {
82+
name string
83+
cmd commands.Command
84+
want [][]string
85+
wantErr bool
86+
}{
87+
{
88+
name: "Transaction in progress",
89+
cmd: commands.CommandFromString("BEGIN IMMEDIATE TABLE Users"),
90+
want: nil,
91+
wantErr: true,
92+
},
93+
{
94+
name: "Transaction in progress",
95+
cmd: commands.CommandFromString("BEGIN IMMEDIATE TABLE (Users"),
96+
want: nil,
97+
wantErr: true,
98+
},
99+
{
100+
name: "Valid command with tables",
101+
cmd: commands.CommandFromString("BEGIN IMMEDIATE TABLE (Users, Orders)"),
102+
want: [][]string{{"MAIN.USERS", "MAIN.ORDERS"}, nil, nil},
103+
wantErr: false,
104+
},
105+
{
106+
name: "Valid command with tables",
107+
cmd: commands.CommandFromString("BEGIN IMMEDIATE TABLE ( Users, Orders )"),
108+
want: [][]string{{"MAIN.USERS", "MAIN.ORDERS"}, nil, nil},
109+
wantErr: false,
110+
},
111+
{
112+
name: "Valid command with tables",
113+
cmd: commands.CommandFromString("BEGIN IMMEDIATE TABLE ( Users, Orders ) VIEW (STUFF)"),
114+
want: [][]string{{"MAIN.USERS", "MAIN.ORDERS"}, {"MAIN.STUFF"}, nil},
115+
wantErr: false,
116+
},
117+
{
118+
name: "Invalid command length",
119+
cmd: commands.CommandFromString("BEGIN IMMEDIATE"),
120+
want: nil,
121+
wantErr: true,
122+
},
123+
}
124+
125+
for _, tt := range tests {
126+
t.Run(tt.name, func(t *testing.T) {
127+
s := &Socket{}
128+
got, got2, got3, err := s.SanitizeBegin(tt.cmd)
129+
if tt.wantErr {
130+
assert.Error(t, err)
131+
} else {
132+
assert.NoError(t, err)
133+
assert.Equal(t, tt.want, [][]string{got, got2, got3})
134+
}
135+
})
136+
}
137+
}

0 commit comments

Comments
 (0)