Skip to content

Commit 725da3b

Browse files
committed
Merge branch 'main' into zachmu/enginetests5
2 parents df523a7 + 2c6e3dd commit 725da3b

File tree

15 files changed

+2903
-2960
lines changed

15 files changed

+2903
-2960
lines changed

engine.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar
454454

455455
// PrepQueryPlanForExecution prepares a query plan for execution and returns the result schema with a row iterator to
456456
// begin spooling results
457-
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
457+
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
458458
// Give the integrator a chance to reject the session before proceeding
459459
// TODO: this check doesn't belong here
460460
err := ctx.Session.ValidateSession(ctx)
@@ -482,9 +482,9 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.
482482
return nil, nil, nil, err
483483
}
484484

485-
iter = finalizeIters(ctx, plan, nil, iter)
485+
iter = finalizeIters(ctx, plan, qFlags, iter)
486486

487-
return plan.Schema(), iter, nil, nil
487+
return plan.Schema(), iter, qFlags, nil
488488
}
489489

490490
// BoundQueryPlan returns query plan for the given statement with the given bindings applied
@@ -866,6 +866,7 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {
866866

867867
// finalizeIters applies the final transformations on sql.RowIter before execution.
868868
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
869+
iter = rowexec.AddTriggerRollbackIter(ctx, qFlags, iter)
869870
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
870871
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
871872
iter = rowexec.AddExpressionCloser(analyzed, iter)

enginetest/queries/integration_plans.go

Lines changed: 2728 additions & 2743 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/handler.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,15 @@ func (h *Handler) doQuery(
372372
bindings map[string]*querypb.BindVariable,
373373
callback func(*sqltypes.Result, bool) error,
374374
qFlags *sql.QueryFlags,
375-
) (string, error) {
376-
sqlCtx, err := h.sm.NewContext(ctx, c, query)
375+
) (remainder string, err error) {
376+
var sqlCtx *sql.Context
377+
sqlCtx, err = h.sm.NewContext(ctx, c, query)
377378
if err != nil {
378379
return "", err
379380
}
380381

381382
start := time.Now()
382383

383-
var remainder string
384384
var prequery string
385385
if parsed == nil {
386386
_, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(sqlCtx.Session.ID(), query)
@@ -411,23 +411,24 @@ func (h *Handler) doQuery(
411411
sqlCtx.GetLogger().Debugf("Starting query")
412412

413413
finish := observeQuery(sqlCtx, query)
414-
defer finish(err)
414+
defer func() {
415+
finish(err)
416+
}()
415417

416418
sqlCtx.GetLogger().Tracef("beginning execution")
417419

418-
oCtx := ctx
419-
420420
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
421421
// marked done until we're done spooling rows over the wire
422-
ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
423-
defer func() {
424-
if err != nil && ctx != nil {
425-
sqlCtx.ProcessList.EndQuery(sqlCtx)
426-
}
427-
}()
422+
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
423+
if err != nil {
424+
return remainder, err
425+
}
426+
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
428427

428+
var schema sql.Schema
429+
var rowIter sql.RowIter
429430
qFlags.Set(sql.QFlagDeferProjections)
430-
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
431+
schema, rowIter, qFlags, err = queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
431432
if err != nil {
432433
sqlCtx.GetLogger().WithError(err).Warn("error running query")
433434
if verboseErrorLogging {
@@ -455,9 +456,6 @@ func (h *Handler) doQuery(
455456
return remainder, err
456457
}
457458

458-
// errGroup context is now canceled
459-
ctx = oCtx
460-
461459
if err = setConnStatusFlags(sqlCtx, c); err != nil {
462460
return remainder, err
463461
}
@@ -1084,9 +1082,9 @@ func (h *Handler) executeBoundPlan(
10841082
_ sqlparser.Statement,
10851083
plan sql.Node,
10861084
_ map[string]*querypb.BindVariable,
1087-
_ *sql.QueryFlags,
1085+
qFlags *sql.QueryFlags,
10881086
) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
1089-
return h.e.PrepQueryPlanForExecution(ctx, query, plan)
1087+
return h.e.PrepQueryPlanForExecution(ctx, query, plan, qFlags)
10901088
}
10911089

10921090
func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sqlparser.Expr, error) {

server/handler_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"time"
2626

2727
"github.com/dolthub/vitess/go/mysql"
28+
"github.com/dolthub/vitess/go/race"
2829
"github.com/dolthub/vitess/go/sqltypes"
2930
"github.com/dolthub/vitess/go/vt/proto/query"
3031
"github.com/stretchr/testify/assert"
@@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) {
742743
require.Len(handler.sm.sessions, 1)
743744
}
744745

746+
func TestHandlerKillQuery(t *testing.T) {
747+
if race.Enabled {
748+
t.Skip("this test is inherently racey")
749+
}
750+
require := require.New(t)
751+
e, pro := setupMemDB(require)
752+
dbFunc := pro.Database
753+
754+
handler := &Handler{
755+
e: e,
756+
sm: NewSessionManager(
757+
func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
758+
return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
759+
},
760+
sql.NoopTracer,
761+
dbFunc,
762+
e.MemoryManager,
763+
e.ProcessList,
764+
"foo",
765+
),
766+
}
767+
768+
var err error
769+
conn1 := newConn(1)
770+
handler.NewConnection(conn1)
771+
772+
conn2 := newConn(2)
773+
handler.NewConnection(conn2)
774+
775+
require.Len(handler.sm.connections, 2)
776+
require.Len(handler.sm.sessions, 0)
777+
778+
handler.ComInitDB(conn1, "test")
779+
err = handler.sm.SetDB(conn1, "test")
780+
require.NoError(err)
781+
782+
err = handler.sm.SetDB(conn2, "test")
783+
require.NoError(err)
784+
785+
require.False(conn1.Conn.(*mockConn).closed)
786+
require.False(conn2.Conn.(*mockConn).closed)
787+
require.Len(handler.sm.connections, 2)
788+
require.Len(handler.sm.sessions, 2)
789+
790+
var wg sync.WaitGroup
791+
wg.Add(1)
792+
sleepQuery := "SELECT SLEEP(1)"
793+
go func() {
794+
defer wg.Done()
795+
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
796+
return nil
797+
})
798+
require.Error(err)
799+
}()
800+
801+
time.Sleep(100 * time.Millisecond)
802+
var sleepQueryID string
803+
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
804+
// 1, , , test, Query, 0, ... , SELECT SLEEP(1000)
805+
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
806+
require.Equal(2, len(res.Rows))
807+
hasSleepQuery := false
808+
for _, row := range res.Rows {
809+
if row[7].ToString() != sleepQuery {
810+
continue
811+
}
812+
hasSleepQuery = true
813+
sleepQueryID = row[0].ToString()
814+
require.Equal("Query", row[4].ToString())
815+
}
816+
require.True(hasSleepQuery)
817+
return nil
818+
})
819+
require.NoError(err)
820+
821+
time.Sleep(100 * time.Millisecond)
822+
err = handler.ComQuery(context.Background(), conn2, "KILL QUERY "+sleepQueryID, func(res *sqltypes.Result, more bool) error {
823+
return nil
824+
})
825+
require.NoError(err)
826+
wg.Wait()
827+
828+
time.Sleep(100 * time.Millisecond)
829+
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
830+
// 1, , , test, Sleep, 0, ,
831+
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
832+
require.Equal(2, len(res.Rows))
833+
hasSleepQueryID := false
834+
for _, row := range res.Rows {
835+
if row[0].ToString() != sleepQueryID {
836+
continue
837+
}
838+
hasSleepQueryID = true
839+
require.Equal("Sleep", row[4].ToString())
840+
require.Equal("", row[7].ToString())
841+
}
842+
require.True(hasSleepQueryID)
843+
return nil
844+
})
845+
require.NoError(err)
846+
847+
require.False(conn1.Conn.(*mockConn).closed)
848+
require.False(conn2.Conn.(*mockConn).closed)
849+
require.Len(handler.sm.connections, 2)
850+
require.Len(handler.sm.sessions, 2)
851+
}
852+
745853
func TestSchemaToFields(t *testing.T) {
746854
require := require.New(t)
747855

sql/analyzer/rule_ids.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ const (
6868
assignRoutinesId // assignRoutines
6969
modifyUpdateExprsForJoinId // modifyUpdateExprsForJoin
7070
applyUpdateAccumulatorsId // applyUpdateAccumulators
71-
wrapWithRollbackId // wrapWithRollback
7271
applyForeignKeysId // applyForeignKeys
7372

7473
// validate

sql/analyzer/ruleid_string.go

Lines changed: 17 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sql/analyzer/rules.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ func init() {
2222
{applyTriggersId, applyTriggers},
2323
{applyProceduresId, applyProcedures},
2424
{applyUpdateAccumulatorsId, applyUpdateAccumulators},
25-
{wrapWithRollbackId, wrapWithRollback},
2625
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
2726
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
2827
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},

sql/analyzer/triggers.go

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
346346

347347
switch n := c.Node.(type) {
348348
case *plan.InsertInto:
349+
qFlags.Set(sql.QFlagTrigger)
349350
if trigger.TriggerTime == sqlparser.BeforeStr {
350351
triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
351352
Name: trigger.TriggerName,
@@ -359,6 +360,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
359360
}), transform.NewTree, nil
360361
}
361362
case *plan.Update:
363+
qFlags.Set(sql.QFlagTrigger)
362364
if trigger.TriggerTime == sqlparser.BeforeStr {
363365
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
364366
Name: trigger.TriggerName,
@@ -387,6 +389,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
387389
"does not support triggers; retry with single table deletes")
388390
}
389391

392+
qFlags.Set(sql.QFlagTrigger)
390393
if trigger.TriggerTime == sqlparser.BeforeStr {
391394
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
392395
Name: trigger.TriggerName,
@@ -517,41 +520,3 @@ func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.Create
517520
func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool {
518521
return strings.ToLower((string)(event)) == strings.ToLower(event2)
519522
}
520-
521-
// wrapWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors
522-
func wrapWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
523-
// Check if tree contains a TriggerExecutor
524-
containsTrigger := false
525-
transform.Inspect(n, func(n sql.Node) bool {
526-
// After Triggers wrap nodes
527-
if _, ok := n.(*plan.TriggerExecutor); ok {
528-
containsTrigger = true
529-
return false // done, don't bother to recurse
530-
}
531-
532-
// Before Triggers on Inserts are inside Source
533-
if n, ok := n.(*plan.InsertInto); ok {
534-
if _, ok := n.Source.(*plan.TriggerExecutor); ok {
535-
containsTrigger = true
536-
return false
537-
}
538-
}
539-
540-
// Before Triggers on Delete and Update should be in children
541-
return true
542-
})
543-
544-
// No TriggerExecutor, so return same tree
545-
if !containsTrigger {
546-
return n, transform.SameTree, nil
547-
}
548-
549-
// If we don't have a transaction session we can't do rollbacks
550-
_, ok := ctx.Session.(sql.TransactionSession)
551-
if !ok {
552-
return plan.NewNoopTriggerRollback(n), transform.NewTree, nil
553-
}
554-
555-
// Wrap tree with new node
556-
return plan.NewTriggerRollback(n), transform.NewTree, nil
557-
}

0 commit comments

Comments
 (0)