Skip to content

Commit e992aa5

Browse files
committed
processlist: Allow for killing the context associated with non-query operations like SetDB and Prepare.
The processlist maintains a Process struct for a given mysql.Conn, and registers a context.CancelFunc for the running operation when the handler dispatches it. This works for queries today. This PR makes it so we also register CancelFunc callbacks for operations which touch the database but do not put the connection into CommandQuery, such as Prepare and ComInit.
1 parent b3a4c87 commit e992aa5

File tree

5 files changed

+85
-10
lines changed

5 files changed

+85
-10
lines changed

processlist.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,40 @@ func (pl *ProcessList) EndQuery(ctx *sql.Context) {
169169
}
170170
}
171171

172+
// Registers the process and session associated with |ctx| as performing
173+
// a long-running operation that should be able to be canceled with Kill.
174+
//
175+
// This is not used for Query processing --- the process is still in
176+
// CommandSleep, it does not have a QueryPid, etc. Must always be
177+
// bracketed with EndOperation(). Should certainly be used for any
178+
// Handler callbacks which may access the database, like Prepare.
179+
func (pl *ProcessList) BeginOperation(ctx *sql.Context) (*sql.Context, error) {
180+
pl.mu.Lock()
181+
defer pl.mu.Unlock()
182+
id := ctx.Session.ID()
183+
p := pl.procs[id]
184+
if p == nil {
185+
return nil, errors.New("internal error: connection not registered with process list")
186+
}
187+
if p.Kill != nil {
188+
return nil, errors.New("internal error: attempt to begin operation on connection which was already running one")
189+
}
190+
newCtx, cancel := ctx.NewSubContext()
191+
p.Kill = cancel
192+
return newCtx, nil
193+
}
194+
195+
func (pl *ProcessList) EndOperation(ctx *sql.Context) {
196+
pl.mu.Lock()
197+
defer pl.mu.Unlock()
198+
id := ctx.Session.ID()
199+
p := pl.procs[id]
200+
if p != nil && p.Kill != nil {
201+
p.Kill()
202+
p.Kill = nil
203+
}
204+
}
205+
172206
// UpdateTableProgress updates the progress of the table with the given name for the
173207
// process with the given pid.
174208
func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) {
@@ -322,7 +356,11 @@ func (pl *ProcessList) Kill(connID uint32) {
322356

323357
p := pl.procs[connID]
324358
if p != nil && p.Kill != nil {
325-
logrus.Infof("kill query: pid %d", p.QueryPid)
359+
if p.QueryPid != 0 {
360+
logrus.Infof("kill query: pid %d", p.QueryPid)
361+
} else {
362+
logrus.Infof("canceling context: connID %d", connID)
363+
}
326364
p.Kill()
327365
}
328366
}

server/context.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
139139
defer sql.SessionCommandEnd(sess)
140140

141141
ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
142+
ctx, err = s.processlist.BeginOperation(ctx)
143+
if err != nil {
144+
return err
145+
}
146+
defer s.processlist.EndOperation(ctx)
142147
var db sql.Database
143148
if dbName != "" {
144149
db, err = s.getDbFunc(ctx, dbName)

server/handler.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {
102102
}
103103

104104
func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
105+
// SetDB itself handles session and processlist operation lifecycle callbacks.
105106
err := h.sm.SetDB(c, schemaName)
106107
if err != nil {
107108
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
@@ -121,6 +122,11 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
121122
if err != nil {
122123
return nil, err
123124
}
125+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
126+
if err != nil {
127+
return nil, err
128+
}
129+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
124130
err = sql.SessionCommandBegin(sqlCtx.Session)
125131
if err != nil {
126132
return nil, err
@@ -166,7 +172,11 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
166172
if err != nil {
167173
return nil, nil, err
168174
}
169-
175+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
176+
if err != nil {
177+
return nil, nil, err
178+
}
179+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
170180
err = sql.SessionCommandBegin(sqlCtx.Session)
171181
if err != nil {
172182
return nil, nil, err
@@ -201,6 +211,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars
201211
if err != nil {
202212
return nil, nil, err
203213
}
214+
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
215+
if err != nil {
216+
return nil, nil, err
217+
}
218+
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
204219
err = sql.SessionCommandBegin(sqlCtx.Session)
205220
if err != nil {
206221
return nil, nil, err
@@ -395,6 +410,13 @@ func (h *Handler) doQuery(
395410
if err != nil {
396411
return "", err
397412
}
413+
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
414+
// marked done until we're done spooling rows over the wire
415+
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
416+
if err != nil {
417+
return remainder, err
418+
}
419+
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
398420
err = sql.SessionCommandBegin(sqlCtx.Session)
399421
if err != nil {
400422
return "", err
@@ -439,14 +461,6 @@ func (h *Handler) doQuery(
439461

440462
sqlCtx.GetLogger().Tracef("beginning execution")
441463

442-
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
443-
// marked done until we're done spooling rows over the wire
444-
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
445-
if err != nil {
446-
return remainder, err
447-
}
448-
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
449-
450464
var schema sql.Schema
451465
var rowIter sql.RowIter
452466
qFlags.Set(sql.QFlagDeferProjections)

server/handler_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ func TestServerEventListener(t *testing.T) {
678678
require.Equal(listener.Disconnects, 2)
679679

680680
conn3 := newConn(3)
681+
handler.NewConnection(conn3)
681682
query := "SELECT ?"
682683
_, err = handler.ComPrepare(context.Background(), conn3, query, samplePrepareData)
683684
require.NoError(err)
@@ -1165,6 +1166,8 @@ func TestHandlerFoundRowsCapabilities(t *testing.T) {
11651166
),
11661167
}
11671168

1169+
handler.NewConnection(dummyConn)
1170+
11681171
tests := []struct {
11691172
name string
11701173
handler *Handler

sql/processlist.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ type ProcessList interface {
4040
// EndQuery transitions a previously transitioned connection from Command "Query" to Command "Sleep".
4141
EndQuery(ctx *Context)
4242

43+
// BeginOperation registers and returns a SubContext for a
44+
// long-running operation on the conneciton which does not
45+
// change the process's Command state. This SubContext will be
46+
// killed by a call to |Kill|, and unregistered by a call to
47+
// |EndOperation|.
48+
BeginOperation(ctx *Context) (*Context, error)
49+
50+
// EndOperation cancels and deregisters the SubContext which
51+
// BeginOperation registered.
52+
EndOperation(ctx *Context)
53+
4354
// Kill terminates all queries for a given connection id
4455
Kill(connID uint32)
4556

@@ -166,6 +177,10 @@ func (e EmptyProcessList) BeginQuery(ctx *Context, query string) (*Context, erro
166177
return ctx, nil
167178
}
168179
func (e EmptyProcessList) EndQuery(ctx *Context) {}
180+
func (e EmptyProcessList) BeginOperation(ctx *Context) (*Context, error) {
181+
return ctx, nil
182+
}
183+
func (e EmptyProcessList) EndOperation(ctx *Context) {}
169184

170185
func (e EmptyProcessList) Kill(connID uint32) {}
171186
func (e EmptyProcessList) Done(pid uint64) {}

0 commit comments

Comments
 (0)