diff --git a/server/handler.go b/server/handler.go index a6e0afd399..9736b8f035 100644 --- a/server/handler.go +++ b/server/handler.go @@ -372,15 +372,15 @@ func (h *Handler) doQuery( bindings map[string]*querypb.BindVariable, callback func(*sqltypes.Result, bool) error, qFlags *sql.QueryFlags, -) (string, error) { - sqlCtx, err := h.sm.NewContext(ctx, c, query) +) (remainder string, err error) { + var sqlCtx *sql.Context + sqlCtx, err = h.sm.NewContext(ctx, c, query) if err != nil { return "", err } start := time.Now() - var remainder string var prequery string if parsed == nil { _, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(sqlCtx.Session.ID(), query) @@ -411,23 +411,24 @@ func (h *Handler) doQuery( sqlCtx.GetLogger().Debugf("Starting query") finish := observeQuery(sqlCtx, query) - defer finish(err) + defer func() { + finish(err) + }() sqlCtx.GetLogger().Tracef("beginning execution") - oCtx := ctx - // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be // marked done until we're done spooling rows over the wire - ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) - defer func() { - if err != nil && ctx != nil { - sqlCtx.ProcessList.EndQuery(sqlCtx) - } - }() + sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) + if err != nil { + return remainder, err + } + defer sqlCtx.ProcessList.EndQuery(sqlCtx) + var schema sql.Schema + var rowIter sql.RowIter qFlags.Set(sql.QFlagDeferProjections) - schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) + schema, rowIter, qFlags, err = queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) if err != nil { sqlCtx.GetLogger().WithError(err).Warn("error running query") if verboseErrorLogging { @@ -455,9 +456,6 @@ func (h *Handler) doQuery( return remainder, err } - // errGroup context is now canceled - ctx = oCtx - if err = setConnStatusFlags(sqlCtx, c); err != nil { return remainder, err } diff --git a/server/handler_test.go b/server/handler_test.go index 38707c79a5..6aac1c4c22 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/race" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/stretchr/testify/assert" @@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) { require.Len(handler.sm.sessions, 1) } +func TestHandlerKillQuery(t *testing.T) { + if race.Enabled { + t.Skip("this test is inherently racey") + } + require := require.New(t) + e, pro := setupMemDB(require) + dbFunc := pro.Database + + handler := &Handler{ + e: e, + sm: NewSessionManager( + func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { + return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil + }, + sql.NoopTracer, + dbFunc, + e.MemoryManager, + e.ProcessList, + "foo", + ), + } + + var err error + conn1 := newConn(1) + handler.NewConnection(conn1) + + conn2 := newConn(2) + handler.NewConnection(conn2) + + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 0) + + handler.ComInitDB(conn1, "test") + err = handler.sm.SetDB(conn1, "test") + require.NoError(err) + + err = handler.sm.SetDB(conn2, "test") + require.NoError(err) + + require.False(conn1.Conn.(*mockConn).closed) + require.False(conn2.Conn.(*mockConn).closed) + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 2) + + var wg sync.WaitGroup + wg.Add(1) + sleepQuery := "SELECT SLEEP(1)" + go func() { + defer wg.Done() + err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { + return nil + }) + require.Error(err) + }() + + time.Sleep(100 * time.Millisecond) + var sleepQueryID string + err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error { + // 1, , , test, Query, 0, ... , SELECT SLEEP(1000) + // 2, , , test, Query, 0, running, SHOW PROCESSLIST + require.Equal(2, len(res.Rows)) + hasSleepQuery := false + for _, row := range res.Rows { + if row[7].ToString() != sleepQuery { + continue + } + hasSleepQuery = true + sleepQueryID = row[0].ToString() + require.Equal("Query", row[4].ToString()) + } + require.True(hasSleepQuery) + return nil + }) + require.NoError(err) + + time.Sleep(100 * time.Millisecond) + err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error { + return nil + }) + require.NoError(err) + wg.Wait() + + time.Sleep(100 * time.Millisecond) + err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error { + // 1, , , test, Sleep, 0, , + // 2, , , test, Query, 0, running, SHOW PROCESSLIST + require.Equal(2, len(res.Rows)) + hasSleepQueryID := false + for _, row := range res.Rows { + if row[0].ToString() != sleepQueryID { + continue + } + hasSleepQueryID = true + require.Equal("Sleep", row[4].ToString()) + require.Equal("", row[7].ToString()) + } + require.True(hasSleepQueryID) + return nil + }) + require.NoError(err) + + require.False(conn1.Conn.(*mockConn).closed) + require.False(conn2.Conn.(*mockConn).closed) + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 2) +} + func TestSchemaToFields(t *testing.T) { require := require.New(t)