Skip to content

Commit 1e0e5e8

Browse files
authored
Merge pull request #2850 from dolthub/aaron/processlist-kill-on-prepare
processlist: Allow for killing the context associated with non-query operations like SetDB and Prepare.
2 parents b3a4c87 + 6887d52 commit 1e0e5e8

File tree

6 files changed

+138
-10
lines changed

6 files changed

+138
-10
lines changed

processlist.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,40 @@ func (pl *ProcessList) EndQuery(ctx *sql.Context) {
169169
}
170170
}
171171

172+
// Registers the process and session associated with |ctx| as performing
173+
// a long-running operation that should be able to be canceled with Kill.
174+
//
175+
// This is not used for Query processing --- the process is still in
176+
// CommandSleep, it does not have a QueryPid, etc. Must always be
177+
// bracketed with EndOperation(). Should certainly be used for any
178+
// Handler callbacks which may access the database, like Prepare.
179+
func (pl *ProcessList) BeginOperation(ctx *sql.Context) (*sql.Context, error) {
180+
pl.mu.Lock()
181+
defer pl.mu.Unlock()
182+
id := ctx.Session.ID()
183+
p := pl.procs[id]
184+
if p == nil {
185+
return nil, errors.New("internal error: connection not registered with process list")
186+
}
187+
if p.Kill != nil {
188+
return nil, errors.New("internal error: attempt to begin operation on connection which was already running one")
189+
}
190+
newCtx, cancel := ctx.NewSubContext()
191+
p.Kill = cancel
192+
return newCtx, nil
193+
}
194+
195+
func (pl *ProcessList) EndOperation(ctx *sql.Context) {
196+
pl.mu.Lock()
197+
defer pl.mu.Unlock()
198+
id := ctx.Session.ID()
199+
p := pl.procs[id]
200+
if p != nil && p.Kill != nil {
201+
p.Kill()
202+
p.Kill = nil
203+
}
204+
}
205+
172206
// UpdateTableProgress updates the progress of the table with the given name for the
173207
// process with the given pid.
174208
func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) {
@@ -322,7 +356,11 @@ func (pl *ProcessList) Kill(connID uint32) {
322356

323357
p := pl.procs[connID]
324358
if p != nil && p.Kill != nil {
325-
logrus.Infof("kill query: pid %d", p.QueryPid)
359+
if p.QueryPid != 0 {
360+
logrus.Infof("kill query: pid %d", p.QueryPid)
361+
} else {
362+
logrus.Infof("canceling context: connID %d", connID)
363+
}
326364
p.Kill()
327365
}
328366
}

processlist_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,59 @@ func TestKillConnection(t *testing.T) {
171171
require.False(t, killed[2])
172172
}
173173

174+
func TestBeginEndOperation(t *testing.T) {
175+
knownSession := sql.NewBaseSessionWithClientServer("", sql.Client{}, 1)
176+
unknownSession := sql.NewBaseSessionWithClientServer("", sql.Client{}, 2)
177+
178+
pl := NewProcessList()
179+
pl.AddConnection(1, "")
180+
181+
// Begining an operation with an unknown connection returns an error.
182+
ctx := sql.NewContext(context.Background(), sql.WithSession(unknownSession))
183+
_, err := pl.BeginOperation(ctx)
184+
require.Error(t, err)
185+
186+
// Can begin and end operation before connection is ready.
187+
ctx = sql.NewContext(context.Background(), sql.WithSession(knownSession))
188+
subCtx, err := pl.BeginOperation(ctx)
189+
require.NoError(t, err)
190+
pl.EndOperation(subCtx)
191+
192+
// Can begin and end operation across the connection ready boundary.
193+
subCtx, err = pl.BeginOperation(ctx)
194+
require.NoError(t, err)
195+
pl.ConnectionReady(knownSession)
196+
pl.EndOperation(subCtx)
197+
198+
// Ending the operation cancels the subcontext.
199+
subCtx, err = pl.BeginOperation(ctx)
200+
require.NoError(t, err)
201+
done := make(chan struct{})
202+
context.AfterFunc(subCtx, func() {
203+
close(done)
204+
})
205+
pl.EndOperation(subCtx)
206+
<-done
207+
208+
// Kill on the connection cancels the subcontext.
209+
subCtx, err = pl.BeginOperation(ctx)
210+
require.NoError(t, err)
211+
done = make(chan struct{})
212+
context.AfterFunc(subCtx, func() {
213+
close(done)
214+
})
215+
pl.Kill(1)
216+
<-done
217+
pl.EndOperation(subCtx)
218+
219+
// Beginning an operation while one is outstanding errors.
220+
subCtx, err = pl.BeginOperation(ctx)
221+
require.NoError(t, err)
222+
_, err = pl.BeginOperation(ctx)
223+
require.Error(t, err)
224+
pl.EndOperation(subCtx)
225+
}
226+
174227
// TestSlowQueryTracking tests that processes that take longer than @@long_query_time increment the
175228
// Slow_queries status variable.
176229
func TestSlowQueryTracking(t *testing.T) {

server/context.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
139139
defer sql.SessionCommandEnd(sess)
140140

141141
ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
142+
ctx, err = s.processlist.BeginOperation(ctx)
143+
if err != nil {
144+
return err
145+
}
146+
defer s.processlist.EndOperation(ctx)
142147
var db sql.Database
143148
if dbName != "" {
144149
db, err = s.getDbFunc(ctx, dbName)

server/handler.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {
102102
}
103103

104104
func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
105+
// SetDB itself handles session and processlist operation lifecycle callbacks.
105106
err := h.sm.SetDB(c, schemaName)
106107
if err != nil {
107108
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
@@ -121,6 +122,11 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
121122
if err != nil {
122123
return nil, err
123124
}
125+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
126+
if err != nil {
127+
return nil, err
128+
}
129+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
124130
err = sql.SessionCommandBegin(sqlCtx.Session)
125131
if err != nil {
126132
return nil, err
@@ -166,7 +172,11 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
166172
if err != nil {
167173
return nil, nil, err
168174
}
169-
175+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
176+
if err != nil {
177+
return nil, nil, err
178+
}
179+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
170180
err = sql.SessionCommandBegin(sqlCtx.Session)
171181
if err != nil {
172182
return nil, nil, err
@@ -201,6 +211,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars
201211
if err != nil {
202212
return nil, nil, err
203213
}
214+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
215+
if err != nil {
216+
return nil, nil, err
217+
}
218+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
204219
err = sql.SessionCommandBegin(sqlCtx.Session)
205220
if err != nil {
206221
return nil, nil, err
@@ -395,6 +410,13 @@ func (h *Handler) doQuery(
395410
if err != nil {
396411
return "", err
397412
}
413+
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
414+
// marked done until we're done spooling rows over the wire
415+
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
416+
if err != nil {
417+
return remainder, err
418+
}
419+
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
398420
err = sql.SessionCommandBegin(sqlCtx.Session)
399421
if err != nil {
400422
return "", err
@@ -439,14 +461,6 @@ func (h *Handler) doQuery(
439461

440462
sqlCtx.GetLogger().Tracef("beginning execution")
441463

442-
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
443-
// marked done until we're done spooling rows over the wire
444-
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
445-
if err != nil {
446-
return remainder, err
447-
}
448-
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
449-
450464
var schema sql.Schema
451465
var rowIter sql.RowIter
452466
qFlags.Set(sql.QFlagDeferProjections)

server/handler_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ func TestServerEventListener(t *testing.T) {
678678
require.Equal(listener.Disconnects, 2)
679679

680680
conn3 := newConn(3)
681+
handler.NewConnection(conn3)
681682
query := "SELECT ?"
682683
_, err = handler.ComPrepare(context.Background(), conn3, query, samplePrepareData)
683684
require.NoError(err)
@@ -1165,6 +1166,8 @@ func TestHandlerFoundRowsCapabilities(t *testing.T) {
11651166
),
11661167
}
11671168

1169+
handler.NewConnection(dummyConn)
1170+
11681171
tests := []struct {
11691172
name string
11701173
handler *Handler

sql/processlist.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ type ProcessList interface {
4040
// EndQuery transitions a previously transitioned connection from Command "Query" to Command "Sleep".
4141
EndQuery(ctx *Context)
4242

43+
// BeginOperation registers and returns a SubContext for a
44+
// long-running operation on the conneciton which does not
45+
// change the process's Command state. This SubContext will be
46+
// killed by a call to |Kill|, and unregistered by a call to
47+
// |EndOperation|.
48+
BeginOperation(ctx *Context) (*Context, error)
49+
50+
// EndOperation cancels and deregisters the SubContext which
51+
// BeginOperation registered.
52+
EndOperation(ctx *Context)
53+
4354
// Kill terminates all queries for a given connection id
4455
Kill(connID uint32)
4556

@@ -166,6 +177,10 @@ func (e EmptyProcessList) BeginQuery(ctx *Context, query string) (*Context, erro
166177
return ctx, nil
167178
}
168179
func (e EmptyProcessList) EndQuery(ctx *Context) {}
180+
func (e EmptyProcessList) BeginOperation(ctx *Context) (*Context, error) {
181+
return ctx, nil
182+
}
183+
func (e EmptyProcessList) EndOperation(ctx *Context) {}
169184

170185
func (e EmptyProcessList) Kill(connID uint32) {}
171186
func (e EmptyProcessList) Done(pid uint64) {}

0 commit comments

Comments
 (0)