From bcc247df1d576e8d129fd12c05c8c33bbe54d9e5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 24 Oct 2024 18:22:29 -0700 Subject: [PATCH 01/10] actually use cancel context --- processlist.go | 5 ++++- server/handler.go | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/processlist.go b/processlist.go index 9069e7359c..344b2a2a84 100644 --- a/processlist.go +++ b/processlist.go @@ -135,7 +135,10 @@ func (pl *ProcessList) BeginQuery( p.Query = query p.QueryPid = pid p.StartedAt = time.Now() - p.Kill = cancel + p.Kill = func(){ + print("KILL QUERY!!!\n") + cancel() + } p.Progress = make(map[string]sql.TableProgress) pl.byQueryPid[ctx.Pid()] = ctx.Session.ID() diff --git a/server/handler.go b/server/handler.go index a6e0afd399..9e63fcc734 100644 --- a/server/handler.go +++ b/server/handler.go @@ -415,14 +415,14 @@ func (h *Handler) doQuery( sqlCtx.GetLogger().Tracef("beginning execution") - oCtx := ctx + //oCtx := ctx // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be // marked done until we're done spooling rows over the wire - ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) + sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) defer func() { if err != nil && ctx != nil { - sqlCtx.ProcessList.EndQuery(sqlCtx) + sqlCtx.ProcessList.EndQuery(sqlCtx) // TODO: should this be ctx? } }() @@ -456,7 +456,7 @@ func (h *Handler) doQuery( } // errGroup context is now canceled - ctx = oCtx + //ctx = oCtx if err = setConnStatusFlags(sqlCtx, c); err != nil { return remainder, err From 87f96d5847dacaff8424330a02ac657083d173a6 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 25 Oct 2024 06:48:42 +0000 Subject: [PATCH 02/10] [ga-format-pr] Run ./format_repo.sh to fix formatting --- processlist.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/processlist.go b/processlist.go index 344b2a2a84..b19a8b2afc 100644 --- a/processlist.go +++ b/processlist.go @@ -135,7 +135,7 @@ func (pl *ProcessList) BeginQuery( p.Query = query p.QueryPid = pid p.StartedAt = time.Now() - p.Kill = func(){ + p.Kill = func() { print("KILL QUERY!!!\n") cancel() } From a5265e932688ae3f7472eaeb9ad329033a4dea11 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 00:05:27 -0700 Subject: [PATCH 03/10] remove todos --- server/handler.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/server/handler.go b/server/handler.go index 9e63fcc734..4cbca834e0 100644 --- a/server/handler.go +++ b/server/handler.go @@ -415,14 +415,12 @@ func (h *Handler) doQuery( sqlCtx.GetLogger().Tracef("beginning execution") - //oCtx := ctx - // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be // marked done until we're done spooling rows over the wire sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) defer func() { if err != nil && ctx != nil { - sqlCtx.ProcessList.EndQuery(sqlCtx) // TODO: should this be ctx? + sqlCtx.ProcessList.EndQuery(sqlCtx) } }() @@ -455,9 +453,6 @@ func (h *Handler) doQuery( return remainder, err } - // errGroup context is now canceled - //ctx = oCtx - if err = setConnStatusFlags(sqlCtx, c); err != nil { return remainder, err } From 9572a0ddc331c5524f03238b7de0be95d85e6522 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 00:37:45 -0700 Subject: [PATCH 04/10] remove debug --- processlist.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/processlist.go b/processlist.go index b19a8b2afc..9069e7359c 100644 --- a/processlist.go +++ b/processlist.go @@ -135,10 +135,7 @@ func (pl *ProcessList) BeginQuery( p.Query = query p.QueryPid = pid p.StartedAt = time.Now() - p.Kill = func() { - print("KILL QUERY!!!\n") - cancel() - } + p.Kill = cancel p.Progress = make(map[string]sql.TableProgress) pl.byQueryPid[ctx.Pid()] = ctx.Session.ID() From 4e2eed29fcd5a41e71780b926e66cace565d36b0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 16:27:42 -0700 Subject: [PATCH 05/10] handler test --- server/handler_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/server/handler_test.go b/server/handler_test.go index 38707c79a5..c204c59aa2 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -742,6 +742,92 @@ func TestHandlerKill(t *testing.T) { require.Len(handler.sm.sessions, 1) } +func TestHandlerKillQuery(t *testing.T) { + require := require.New(t) + e, pro := setupMemDB(require) + dbFunc := pro.Database + + handler := &Handler{ + e: e, + sm: NewSessionManager( + func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { + return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil + }, + sql.NoopTracer, + dbFunc, + e.MemoryManager, + e.ProcessList, + "foo", + ), + } + + var err error + conn1 := newConn(1) + handler.NewConnection(conn1) + + conn2 := newConn(2) + handler.NewConnection(conn2) + + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 0) + + handler.ComInitDB(conn1, "test") + err = handler.sm.SetDB(conn1, "test") + require.NoError(err) + + err = handler.sm.SetDB(conn2, "test") + require.NoError(err) + + require.False(conn1.Conn.(*mockConn).closed) + require.False(conn2.Conn.(*mockConn).closed) + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 2) + + + sleepQuery := "SELECT SLEEP(1000)" + go func() { + err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { + return nil + }) + require.Error(err) + }() + + time.Sleep(100 * time.Millisecond) + var sleepQueryID string + err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error { + // 1, , , test, Query, 0, ... , SELECT SLEEP(1000) + // 2, , , test, Query, 0, running, SHOW PROCESSLIST + require.Equal(2, len(res.Rows)) + sleepQueryID = res.Rows[0][0].ToString() + require.Equal("Query", res.Rows[0][4].ToString()) + require.Equal(sleepQuery, res.Rows[0][7].ToString()) + return nil + }) + require.NoError(err) + + time.Sleep(100 * time.Millisecond) + err = handler.ComQuery(context.Background(), conn2, "KILL QUERY " + sleepQueryID, func(res *sqltypes.Result, more bool) error { + return nil + }) + require.NoError(err) + + time.Sleep(100 * time.Millisecond) + err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error { + // 1, , , test, Sleep, 0, , + // 2, , , test, Query, 0, running, SHOW PROCESSLIST + require.Equal(2, len(res.Rows)) + require.Equal("Sleep", res.Rows[0][4].ToString()) + require.Equal("", res.Rows[0][7].ToString()) + return nil + }) + require.NoError(err) + + require.False(conn1.Conn.(*mockConn).closed) + require.False(conn2.Conn.(*mockConn).closed) + require.Len(handler.sm.connections, 2) + require.Len(handler.sm.sessions, 2) +} + func TestSchemaToFields(t *testing.T) { require := require.New(t) From 3656821255102d5bec2d13bb82d89dd851613604 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 25 Oct 2024 23:29:03 +0000 Subject: [PATCH 06/10] [ga-format-pr] Run ./format_repo.sh to fix formatting --- server/handler_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/handler_test.go b/server/handler_test.go index c204c59aa2..35685db5fe 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -783,7 +783,6 @@ func TestHandlerKillQuery(t *testing.T) { require.Len(handler.sm.connections, 2) require.Len(handler.sm.sessions, 2) - sleepQuery := "SELECT SLEEP(1000)" go func() { err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { @@ -806,7 +805,7 @@ func TestHandlerKillQuery(t *testing.T) { require.NoError(err) time.Sleep(100 * time.Millisecond) - err = handler.ComQuery(context.Background(), conn2, "KILL QUERY " + sleepQueryID, func(res *sqltypes.Result, more bool) error { + err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error { return nil }) require.NoError(err) @@ -817,7 +816,7 @@ func TestHandlerKillQuery(t *testing.T) { // 2, , , test, Query, 0, running, SHOW PROCESSLIST require.Equal(2, len(res.Rows)) require.Equal("Sleep", res.Rows[0][4].ToString()) - require.Equal("", res.Rows[0][7].ToString()) + require.Equal("", res.Rows[0][7].ToString()) return nil }) require.NoError(err) From 5cb33469c075c94f255757ff9389a57f8c373956 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 16:49:58 -0700 Subject: [PATCH 07/10] add waits --- server/handler_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/handler_test.go b/server/handler_test.go index 35685db5fe..832a935ece 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -783,8 +783,11 @@ func TestHandlerKillQuery(t *testing.T) { require.Len(handler.sm.connections, 2) require.Len(handler.sm.sessions, 2) - sleepQuery := "SELECT SLEEP(1000)" + var wg sync.WaitGroup + wg.Add(1) + sleepQuery := "SELECT SLEEP(100)" go func() { + defer wg.Done() err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { return nil }) @@ -821,6 +824,8 @@ func TestHandlerKillQuery(t *testing.T) { }) require.NoError(err) + wg.Wait() + require.False(conn1.Conn.(*mockConn).closed) require.False(conn2.Conn.(*mockConn).closed) require.Len(handler.sm.connections, 2) From 68eae6b32cdfba10f76943475f2ec46468408447 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 17:21:23 -0700 Subject: [PATCH 08/10] fix flake and race --- server/handler_test.go | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/server/handler_test.go b/server/handler_test.go index 832a935ece..4ba172cc0d 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/race" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/stretchr/testify/assert" @@ -743,6 +744,9 @@ func TestHandlerKill(t *testing.T) { } func TestHandlerKillQuery(t *testing.T) { + if race.Enabled { + t.Skip("this test is inherently racey") + } require := require.New(t) e, pro := setupMemDB(require) dbFunc := pro.Database @@ -785,7 +789,7 @@ func TestHandlerKillQuery(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - sleepQuery := "SELECT SLEEP(100)" + sleepQuery := "SELECT SLEEP(1)" go func() { defer wg.Done() err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { @@ -800,9 +804,16 @@ func TestHandlerKillQuery(t *testing.T) { // 1, , , test, Query, 0, ... , SELECT SLEEP(1000) // 2, , , test, Query, 0, running, SHOW PROCESSLIST require.Equal(2, len(res.Rows)) - sleepQueryID = res.Rows[0][0].ToString() - require.Equal("Query", res.Rows[0][4].ToString()) - require.Equal(sleepQuery, res.Rows[0][7].ToString()) + hasSleepQuery := false + for _, row := range res.Rows { + if row[7].ToString() != sleepQuery { + continue + } + hasSleepQuery = true + sleepQueryID = row[0].ToString() + require.Equal("Query", row[4].ToString()) + } + require.True(hasSleepQuery) return nil }) require.NoError(err) @@ -812,19 +823,26 @@ func TestHandlerKillQuery(t *testing.T) { return nil }) require.NoError(err) + wg.Wait() time.Sleep(100 * time.Millisecond) err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error { // 1, , , test, Sleep, 0, , // 2, , , test, Query, 0, running, SHOW PROCESSLIST require.Equal(2, len(res.Rows)) - require.Equal("Sleep", res.Rows[0][4].ToString()) - require.Equal("", res.Rows[0][7].ToString()) + hasSleepQueryID := false + for _, row := range res.Rows { + if row[0].ToString() != sleepQueryID { + continue + } + hasSleepQueryID = true + require.Equal("Sleep", row[4].ToString()) + require.Equal("", row[7].ToString()) + } + require.True(hasSleepQueryID) return nil }) - require.NoError(err) - - wg.Wait() + require.NoError(err)g require.False(conn1.Conn.(*mockConn).closed) require.False(conn2.Conn.(*mockConn).closed) From 3ba9230b60fcc8f799c76ef8cc66c20cc382ebec Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Oct 2024 17:21:43 -0700 Subject: [PATCH 09/10] g --- server/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/handler_test.go b/server/handler_test.go index 4ba172cc0d..6aac1c4c22 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -842,7 +842,7 @@ func TestHandlerKillQuery(t *testing.T) { require.True(hasSleepQueryID) return nil }) - require.NoError(err)g + require.NoError(err) require.False(conn1.Conn.(*mockConn).closed) require.False(conn2.Conn.(*mockConn).closed) From 9fa6d1b91d20398330b5daf7a76d7150c15fb3ef Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Mon, 28 Oct 2024 07:11:23 +0100 Subject: [PATCH 10/10] server/handle.go: doQuery cleanup of some error handling. (#2723) --- server/handler.go | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/server/handler.go b/server/handler.go index 4cbca834e0..9736b8f035 100644 --- a/server/handler.go +++ b/server/handler.go @@ -372,15 +372,15 @@ func (h *Handler) doQuery( bindings map[string]*querypb.BindVariable, callback func(*sqltypes.Result, bool) error, qFlags *sql.QueryFlags, -) (string, error) { - sqlCtx, err := h.sm.NewContext(ctx, c, query) +) (remainder string, err error) { + var sqlCtx *sql.Context + sqlCtx, err = h.sm.NewContext(ctx, c, query) if err != nil { return "", err } start := time.Now() - var remainder string var prequery string if parsed == nil { _, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(sqlCtx.Session.ID(), query) @@ -411,21 +411,24 @@ func (h *Handler) doQuery( sqlCtx.GetLogger().Debugf("Starting query") finish := observeQuery(sqlCtx, query) - defer finish(err) + defer func() { + finish(err) + }() sqlCtx.GetLogger().Tracef("beginning execution") // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be // marked done until we're done spooling rows over the wire sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) - defer func() { - if err != nil && ctx != nil { - sqlCtx.ProcessList.EndQuery(sqlCtx) - } - }() + if err != nil { + return remainder, err + } + defer sqlCtx.ProcessList.EndQuery(sqlCtx) + var schema sql.Schema + var rowIter sql.RowIter qFlags.Set(sql.QFlagDeferProjections) - schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) + schema, rowIter, qFlags, err = queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) if err != nil { sqlCtx.GetLogger().WithError(err).Warn("error running query") if verboseErrorLogging {