Skip to content

Commit b637a3f

Browse files
author
James Cor
committed
Merge branch 'main' into james/update
2 parents 973bb30 + a6973b5 commit b637a3f

File tree

2 files changed

+122
-16
lines changed

2 files changed

+122
-16
lines changed

server/handler.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,15 @@ func (h *Handler) doQuery(
372372
bindings map[string]*querypb.BindVariable,
373373
callback func(*sqltypes.Result, bool) error,
374374
qFlags *sql.QueryFlags,
375-
) (string, error) {
376-
sqlCtx, err := h.sm.NewContext(ctx, c, query)
375+
) (remainder string, err error) {
376+
var sqlCtx *sql.Context
377+
sqlCtx, err = h.sm.NewContext(ctx, c, query)
377378
if err != nil {
378379
return "", err
379380
}
380381

381382
start := time.Now()
382383

383-
var remainder string
384384
var prequery string
385385
if parsed == nil {
386386
_, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(sqlCtx.Session.ID(), query)
@@ -411,23 +411,24 @@ func (h *Handler) doQuery(
411411
sqlCtx.GetLogger().Debugf("Starting query")
412412

413413
finish := observeQuery(sqlCtx, query)
414-
defer finish(err)
414+
defer func() {
415+
finish(err)
416+
}()
415417

416418
sqlCtx.GetLogger().Tracef("beginning execution")
417419

418-
oCtx := ctx
419-
420420
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
421421
// marked done until we're done spooling rows over the wire
422-
ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
423-
defer func() {
424-
if err != nil && ctx != nil {
425-
sqlCtx.ProcessList.EndQuery(sqlCtx)
426-
}
427-
}()
422+
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
423+
if err != nil {
424+
return remainder, err
425+
}
426+
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
428427

428+
var schema sql.Schema
429+
var rowIter sql.RowIter
429430
qFlags.Set(sql.QFlagDeferProjections)
430-
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
431+
schema, rowIter, qFlags, err = queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
431432
if err != nil {
432433
sqlCtx.GetLogger().WithError(err).Warn("error running query")
433434
if verboseErrorLogging {
@@ -455,9 +456,6 @@ func (h *Handler) doQuery(
455456
return remainder, err
456457
}
457458

458-
// errGroup context is now canceled
459-
ctx = oCtx
460-
461459
if err = setConnStatusFlags(sqlCtx, c); err != nil {
462460
return remainder, err
463461
}

server/handler_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"time"
2626

2727
"github.com/dolthub/vitess/go/mysql"
28+
"github.com/dolthub/vitess/go/race"
2829
"github.com/dolthub/vitess/go/sqltypes"
2930
"github.com/dolthub/vitess/go/vt/proto/query"
3031
"github.com/stretchr/testify/assert"
@@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) {
742743
require.Len(handler.sm.sessions, 1)
743744
}
744745

746+
func TestHandlerKillQuery(t *testing.T) {
747+
if race.Enabled {
748+
t.Skip("this test is inherently racey")
749+
}
750+
require := require.New(t)
751+
e, pro := setupMemDB(require)
752+
dbFunc := pro.Database
753+
754+
handler := &Handler{
755+
e: e,
756+
sm: NewSessionManager(
757+
func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
758+
return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
759+
},
760+
sql.NoopTracer,
761+
dbFunc,
762+
e.MemoryManager,
763+
e.ProcessList,
764+
"foo",
765+
),
766+
}
767+
768+
var err error
769+
conn1 := newConn(1)
770+
handler.NewConnection(conn1)
771+
772+
conn2 := newConn(2)
773+
handler.NewConnection(conn2)
774+
775+
require.Len(handler.sm.connections, 2)
776+
require.Len(handler.sm.sessions, 0)
777+
778+
handler.ComInitDB(conn1, "test")
779+
err = handler.sm.SetDB(conn1, "test")
780+
require.NoError(err)
781+
782+
err = handler.sm.SetDB(conn2, "test")
783+
require.NoError(err)
784+
785+
require.False(conn1.Conn.(*mockConn).closed)
786+
require.False(conn2.Conn.(*mockConn).closed)
787+
require.Len(handler.sm.connections, 2)
788+
require.Len(handler.sm.sessions, 2)
789+
790+
var wg sync.WaitGroup
791+
wg.Add(1)
792+
sleepQuery := "SELECT SLEEP(1)"
793+
go func() {
794+
defer wg.Done()
795+
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
796+
return nil
797+
})
798+
require.Error(err)
799+
}()
800+
801+
time.Sleep(100 * time.Millisecond)
802+
var sleepQueryID string
803+
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
804+
// 1, , , test, Query, 0, ... , SELECT SLEEP(1000)
805+
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
806+
require.Equal(2, len(res.Rows))
807+
hasSleepQuery := false
808+
for _, row := range res.Rows {
809+
if row[7].ToString() != sleepQuery {
810+
continue
811+
}
812+
hasSleepQuery = true
813+
sleepQueryID = row[0].ToString()
814+
require.Equal("Query", row[4].ToString())
815+
}
816+
require.True(hasSleepQuery)
817+
return nil
818+
})
819+
require.NoError(err)
820+
821+
time.Sleep(100 * time.Millisecond)
822+
err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error {
823+
return nil
824+
})
825+
require.NoError(err)
826+
wg.Wait()
827+
828+
time.Sleep(100 * time.Millisecond)
829+
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
830+
// 1, , , test, Sleep, 0, ,
831+
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
832+
require.Equal(2, len(res.Rows))
833+
hasSleepQueryID := false
834+
for _, row := range res.Rows {
835+
if row[0].ToString() != sleepQueryID {
836+
continue
837+
}
838+
hasSleepQueryID = true
839+
require.Equal("Sleep", row[4].ToString())
840+
require.Equal("", row[7].ToString())
841+
}
842+
require.True(hasSleepQueryID)
843+
return nil
844+
})
845+
require.NoError(err)
846+
847+
require.False(conn1.Conn.(*mockConn).closed)
848+
require.False(conn2.Conn.(*mockConn).closed)
849+
require.Len(handler.sm.connections, 2)
850+
require.Len(handler.sm.sessions, 2)
851+
}
852+
745853
func TestSchemaToFields(t *testing.T) {
746854
require := require.New(t)
747855

0 commit comments

Comments
 (0)