Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,15 @@ func (h *Handler) doQuery(
bindings map[string]*querypb.BindVariable,
callback func(*sqltypes.Result, bool) error,
qFlags *sql.QueryFlags,
) (string, error) {
sqlCtx, err := h.sm.NewContext(ctx, c, query)
) (remainder string, err error) {
var sqlCtx *sql.Context
sqlCtx, err = h.sm.NewContext(ctx, c, query)
if err != nil {
return "", err
}

start := time.Now()

var remainder string
var prequery string
if parsed == nil {
_, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(sqlCtx.Session.ID(), query)
Expand Down Expand Up @@ -411,23 +411,24 @@ func (h *Handler) doQuery(
sqlCtx.GetLogger().Debugf("Starting query")

finish := observeQuery(sqlCtx, query)
defer finish(err)
defer func() {
finish(err)
}()

sqlCtx.GetLogger().Tracef("beginning execution")

oCtx := ctx

// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
defer func() {
if err != nil && ctx != nil {
sqlCtx.ProcessList.EndQuery(sqlCtx)
}
}()
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
if err != nil {
return remainder, err
}
defer sqlCtx.ProcessList.EndQuery(sqlCtx)

var schema sql.Schema
var rowIter sql.RowIter
qFlags.Set(sql.QFlagDeferProjections)
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
schema, rowIter, qFlags, err = queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
if err != nil {
sqlCtx.GetLogger().WithError(err).Warn("error running query")
if verboseErrorLogging {
Expand Down Expand Up @@ -455,9 +456,6 @@ func (h *Handler) doQuery(
return remainder, err
}

// errGroup context is now canceled
ctx = oCtx

if err = setConnStatusFlags(sqlCtx, c); err != nil {
return remainder, err
}
Expand Down
108 changes: 108 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/vitess/go/race"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) {
require.Len(handler.sm.sessions, 1)
}

func TestHandlerKillQuery(t *testing.T) {
if race.Enabled {
t.Skip("this test is inherently racey")
}
require := require.New(t)
e, pro := setupMemDB(require)
dbFunc := pro.Database

handler := &Handler{
e: e,
sm: NewSessionManager(
func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
},
sql.NoopTracer,
dbFunc,
e.MemoryManager,
e.ProcessList,
"foo",
),
}

var err error
conn1 := newConn(1)
handler.NewConnection(conn1)

conn2 := newConn(2)
handler.NewConnection(conn2)

require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 0)

handler.ComInitDB(conn1, "test")
err = handler.sm.SetDB(conn1, "test")
require.NoError(err)

err = handler.sm.SetDB(conn2, "test")
require.NoError(err)

require.False(conn1.Conn.(*mockConn).closed)
require.False(conn2.Conn.(*mockConn).closed)
require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 2)

var wg sync.WaitGroup
wg.Add(1)
sleepQuery := "SELECT SLEEP(1)"
go func() {
defer wg.Done()
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
return nil
})
require.Error(err)
}()

time.Sleep(100 * time.Millisecond)
var sleepQueryID string
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
// 1, , , test, Query, 0, ... , SELECT SLEEP(1000)
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQuery := false
for _, row := range res.Rows {
if row[7].ToString() != sleepQuery {
continue
}
hasSleepQuery = true
sleepQueryID = row[0].ToString()
require.Equal("Query", row[4].ToString())
}
require.True(hasSleepQuery)
return nil
})
require.NoError(err)

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error {
return nil
})
require.NoError(err)
wg.Wait()

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
// 1, , , test, Sleep, 0, ,
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQueryID := false
for _, row := range res.Rows {
if row[0].ToString() != sleepQueryID {
continue
}
hasSleepQueryID = true
require.Equal("Sleep", row[4].ToString())
require.Equal("", row[7].ToString())
}
require.True(hasSleepQueryID)
return nil
})
require.NoError(err)

require.False(conn1.Conn.(*mockConn).closed)
require.False(conn2.Conn.(*mockConn).closed)
require.Len(handler.sm.connections, 2)
require.Len(handler.sm.sessions, 2)
}

func TestSchemaToFields(t *testing.T) {
require := require.New(t)

Expand Down