Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar

// PrepQueryPlanForExecution prepares a query plan for execution and returns the result schema with a row iterator to
// begin spooling results
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
// Give the integrator a chance to reject the session before proceeding
// TODO: this check doesn't belong here
err := ctx.Session.ValidateSession(ctx)
Expand Down Expand Up @@ -482,9 +482,9 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.
return nil, nil, nil, err
}

iter = finalizeIters(ctx, plan, nil, iter)
iter = finalizeIters(ctx, plan, qFlags, iter)

return plan.Schema(), iter, nil, nil
return plan.Schema(), iter, qFlags, nil
}

// BoundQueryPlan returns query plan for the given statement with the given bindings applied
Expand Down Expand Up @@ -866,6 +866,7 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {

// finalizeIters applies the final transformations on sql.RowIter before execution.
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTriggerRollbackIter(ctx, qFlags, iter)
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
iter = rowexec.AddExpressionCloser(analyzed, iter)
Expand Down
5,471 changes: 2,728 additions & 2,743 deletions enginetest/queries/integration_plans.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,9 +1084,9 @@ func (h *Handler) executeBoundPlan(
_ sqlparser.Statement,
plan sql.Node,
_ map[string]*querypb.BindVariable,
_ *sql.QueryFlags,
qFlags *sql.QueryFlags,
) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
return h.e.PrepQueryPlanForExecution(ctx, query, plan)
return h.e.PrepQueryPlanForExecution(ctx, query, plan, qFlags)
}

func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sqlparser.Expr, error) {
Expand Down
1 change: 0 additions & 1 deletion sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ const (
assignRoutinesId // assignRoutines
modifyUpdateExprsForJoinId // modifyUpdateExprsForJoin
applyUpdateAccumulatorsId // applyUpdateAccumulators
wrapWithRollbackId // wrapWithRollback
applyForeignKeysId // applyForeignKeys

// validate
Expand Down
35 changes: 17 additions & 18 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func init() {
{applyTriggersId, applyTriggers},
{applyProceduresId, applyProcedures},
{applyUpdateAccumulatorsId, applyUpdateAccumulators},
{wrapWithRollbackId, wrapWithRollback},
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
Expand Down
41 changes: 3 additions & 38 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope

switch n := c.Node.(type) {
case *plan.InsertInto:
qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand All @@ -359,6 +360,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
}), transform.NewTree, nil
}
case *plan.Update:
qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand Down Expand Up @@ -387,6 +389,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
"does not support triggers; retry with single table deletes")
}

qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand Down Expand Up @@ -517,41 +520,3 @@ func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.Create
func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool {
return strings.ToLower((string)(event)) == strings.ToLower(event2)
}

// wrapWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors
func wrapWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
// Check if tree contains a TriggerExecutor
containsTrigger := false
transform.Inspect(n, func(n sql.Node) bool {
// After Triggers wrap nodes
if _, ok := n.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false // done, don't bother to recurse
}

// Before Triggers on Inserts are inside Source
if n, ok := n.(*plan.InsertInto); ok {
if _, ok := n.Source.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false
}
}

// Before Triggers on Delete and Update should be in children
return true
})

// No TriggerExecutor, so return same tree
if !containsTrigger {
return n, transform.SameTree, nil
}

// If we don't have a transaction session we can't do rollbacks
_, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return plan.NewNoopTriggerRollback(n), transform.NewTree, nil
}

// Wrap tree with new node
return plan.NewTriggerRollback(n), transform.NewTree, nil
}
100 changes: 0 additions & 100 deletions sql/plan/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,103 +103,3 @@ func (t *TriggerExecutor) CheckPrivileges(ctx *sql.Context, opChecker sql.Privil
func (t *TriggerExecutor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.left)
}

// TriggerRollback is a node that wraps the entire tree iff it contains a trigger, creates a savepoint, and performs a
// rollback if something went wrong during execution
type TriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*TriggerRollback)(nil)
var _ sql.CollationCoercible = (*TriggerRollback)(nil)

func NewTriggerRollback(child sql.Node) *TriggerRollback {
return &TriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *TriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *TriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *TriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *TriggerRollback) IsReadOnly() bool {
return t.Child.IsReadOnly()
}

func (t *TriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *TriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}

type NoopTriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*NoopTriggerRollback)(nil)
var _ sql.CollationCoercible = (*NoopTriggerRollback)(nil)

func NewNoopTriggerRollback(child sql.Node) *NoopTriggerRollback {
return &NoopTriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *NoopTriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewNoopTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *NoopTriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *NoopTriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *NoopTriggerRollback) IsReadOnly() bool {
return true
}

func (t *NoopTriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *NoopTriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}
4 changes: 4 additions & 0 deletions sql/query_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (
QFlagDeferProjections
// QFlagUndeferrableExprs indicates that the query has expressions that cannot be deferred
QFlagUndeferrableExprs
QFlagTrigger
)

type QueryFlags struct {
Expand All @@ -69,6 +70,9 @@ func (qp *QueryFlags) Unset(flag int) {
}

func (qp *QueryFlags) IsSet(flag int) bool {
if qp == nil {
return false
}
return qp.Flags.Contains(flag)
}

Expand Down
3 changes: 1 addition & 2 deletions sql/rowexec/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ type ExecBuilderFunc func(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter,
// sql.ExecSourceRel are also built into the tree.
type BaseBuilder struct {
// if override is provided, we try to build executor with this first
override sql.NodeExecBuilder
triggerSavePointCounter int // tracks the number of save points that have been created by triggers
override sql.NodeExecBuilder
}

func (b *BaseBuilder) Build(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter, error) {
Expand Down
26 changes: 0 additions & 26 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,32 +256,6 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
return rowIterWithOkResultWithZeroRowsAffected(), nil
}

func (b *BaseBuilder) buildTriggerRollback(ctx *sql.Context, n *plan.TriggerRollback, row sql.Row) (sql.RowIter, error) {
childIter, err := b.buildNodeExec(ctx, n.Child, row)
if err != nil {
return nil, err
}

savePointCounter := b.triggerSavePointCounter + 1
savePointName := fmt.Sprintf("%s%v", TriggerSavePointPrefix, savePointCounter)
ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", savePointName)

ts, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
}

if err := ts.CreateSavepoint(ctx, ctx.GetTransaction(), savePointName); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}
b.triggerSavePointCounter = savePointCounter

return &triggerRollbackIter{
child: childIter,
savePointName: savePointName,
}, nil
}

func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) {
err := b.executeAlterIndex(ctx, n)
if err != nil {
Expand Down
23 changes: 22 additions & 1 deletion sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Dolthub, Inc.
// Copyright 2023-2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,27 @@ type triggerRollbackIter struct {
savePointName string
}

func AddTriggerRollbackIter(ctx *sql.Context, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
if !qFlags.IsSet(sql.QFlagTrigger) {
return iter
}

transSess, isTransSess := ctx.Session.(sql.TransactionSession)
if !isTransSess {
return iter
}

ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", TriggerSavePointPrefix)
if err := transSess.CreateSavepoint(ctx, ctx.GetTransaction(), TriggerSavePointPrefix); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}

return &triggerRollbackIter{
child: iter,
savePointName: TriggerSavePointPrefix,
}
}

func (t *triggerRollbackIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
childRow, err := t.child.Next(ctx)

Expand Down
4 changes: 0 additions & 4 deletions sql/rowexec/node_builder.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildHaving(ctx, n, row)
case *plan.Signal:
return b.buildSignal(ctx, n, row)
case *plan.TriggerRollback:
return b.buildTriggerRollback(ctx, n, row)
case *plan.ExternalProcedure:
return b.buildExternalProcedure(ctx, n, row)
case *plan.Into:
Expand Down Expand Up @@ -246,8 +244,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildCreateIndex(ctx, n, row)
case *plan.Procedure:
return b.buildProcedure(ctx, n, row)
case *plan.NoopTriggerRollback:
return b.buildNoopTriggerRollback(ctx, n, row)
case *plan.With:
return b.buildWith(ctx, n, row)
case *plan.Project:
Expand Down
5 changes: 0 additions & 5 deletions sql/rowexec/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,6 @@ func (b *BaseBuilder) buildCommit(ctx *sql.Context, n *plan.Commit, row sql.Row)
return sql.RowsToRowIter(), nil
}

func (b *BaseBuilder) buildNoopTriggerRollback(ctx *sql.Context, n *plan.NoopTriggerRollback, row sql.Row) (sql.RowIter, error) {
return b.buildNodeExec(ctx, n.Child, row)

}

func (b *BaseBuilder) buildKill(ctx *sql.Context, n *plan.Kill, row sql.Row) (sql.RowIter, error) {
return &lazyRowIter{
func(ctx *sql.Context) (sql.Row, error) {
Expand Down
Loading