Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,9 @@ func (h *Handler) doQuery(

sqlCtx.GetLogger().Tracef("beginning execution")

oCtx := ctx

// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
defer func() {
if err != nil && ctx != nil {
sqlCtx.ProcessList.EndQuery(sqlCtx)
Expand Down Expand Up @@ -455,9 +453,6 @@ func (h *Handler) doQuery(
return remainder, err
}

// errGroup context is now canceled
ctx = oCtx

if err = setConnStatusFlags(sqlCtx, c); err != nil {
return remainder, err
}
Expand Down
108 changes: 108 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/vitess/go/race"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) {
require.Len(handler.sm.sessions, 1)
}

func TestHandlerKillQuery(t *testing.T) {
if race.Enabled {
t.Skip("this test is inherently racey")
}
require := require.New(t)
e, pro := setupMemDB(require)
dbFunc := pro.Database

handler := &Handler{
e: e,
sm: NewSessionManager(
func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
},
sql.NoopTracer,
dbFunc,
e.MemoryManager,
e.ProcessList,
"foo",
),
}

var err error
conn1 := newConn(1)
handler.NewConnection(conn1)

conn2 := newConn(2)
handler.NewConnection(conn2)

require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 0)

handler.ComInitDB(conn1, "test")
err = handler.sm.SetDB(conn1, "test")
require.NoError(err)

err = handler.sm.SetDB(conn2, "test")
require.NoError(err)

require.False(conn1.Conn.(*mockConn).closed)
require.False(conn2.Conn.(*mockConn).closed)
require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 2)

var wg sync.WaitGroup
wg.Add(1)
sleepQuery := "SELECT SLEEP(1)"
go func() {
defer wg.Done()
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
return nil
})
require.Error(err)
}()

time.Sleep(100 * time.Millisecond)
var sleepQueryID string
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
// 1, , , test, Query, 0, ... , SELECT SLEEP(1000)
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQuery := false
for _, row := range res.Rows {
if row[7].ToString() != sleepQuery {
continue
}
hasSleepQuery = true
sleepQueryID = row[0].ToString()
require.Equal("Query", row[4].ToString())
}
require.True(hasSleepQuery)
return nil
})
require.NoError(err)

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error {
return nil
})
require.NoError(err)
wg.Wait()

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
// 1, , , test, Sleep, 0, ,
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQueryID := false
for _, row := range res.Rows {
if row[0].ToString() != sleepQueryID {
continue
}
hasSleepQueryID = true
require.Equal("Sleep", row[4].ToString())
require.Equal("", row[7].ToString())
}
require.True(hasSleepQueryID)
return nil
})
require.NoError(err)

require.False(conn1.Conn.(*mockConn).closed)
require.False(conn2.Conn.(*mockConn).closed)
require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 2)
}

func TestSchemaToFields(t *testing.T) {
require := require.New(t)

Expand Down