Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {

// finalizeIters applies the final transformations on sql.RowIter before execution.
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTriggerRollbackIter(ctx, analyzed, iter)
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
iter = rowexec.AddExpressionCloser(analyzed, iter)
Expand Down
5,471 changes: 2,728 additions & 2,743 deletions enginetest/queries/integration_plans.go

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ const (
assignRoutinesId // assignRoutines
modifyUpdateExprsForJoinId // modifyUpdateExprsForJoin
applyUpdateAccumulatorsId // applyUpdateAccumulators
wrapWithRollbackId // wrapWithRollback
applyForeignKeysId // applyForeignKeys

// validate
Expand Down
35 changes: 17 additions & 18 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func init() {
{applyTriggersId, applyTriggers},
{applyProceduresId, applyProcedures},
{applyUpdateAccumulatorsId, applyUpdateAccumulators},
{wrapWithRollbackId, wrapWithRollback},
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
Expand Down
38 changes: 0 additions & 38 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,41 +517,3 @@ func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.Create
func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool {
return strings.ToLower((string)(event)) == strings.ToLower(event2)
}

// wrapWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors
func wrapWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
// Check if tree contains a TriggerExecutor
containsTrigger := false
transform.Inspect(n, func(n sql.Node) bool {
// After Triggers wrap nodes
if _, ok := n.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false // done, don't bother to recurse
}

// Before Triggers on Inserts are inside Source
if n, ok := n.(*plan.InsertInto); ok {
if _, ok := n.Source.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false
}
}

// Before Triggers on Delete and Update should be in children
return true
})

// No TriggerExecutor, so return same tree
if !containsTrigger {
return n, transform.SameTree, nil
}

// If we don't have a transaction session we can't do rollbacks
_, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return plan.NewNoopTriggerRollback(n), transform.NewTree, nil
}

// Wrap tree with new node
return plan.NewTriggerRollback(n), transform.NewTree, nil
}
100 changes: 0 additions & 100 deletions sql/plan/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,103 +103,3 @@ func (t *TriggerExecutor) CheckPrivileges(ctx *sql.Context, opChecker sql.Privil
func (t *TriggerExecutor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.left)
}

// TriggerRollback is a node that wraps the entire tree iff it contains a trigger, creates a savepoint, and performs a
// rollback if something went wrong during execution
type TriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*TriggerRollback)(nil)
var _ sql.CollationCoercible = (*TriggerRollback)(nil)

func NewTriggerRollback(child sql.Node) *TriggerRollback {
return &TriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *TriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *TriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *TriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *TriggerRollback) IsReadOnly() bool {
return t.Child.IsReadOnly()
}

func (t *TriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *TriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}

type NoopTriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*NoopTriggerRollback)(nil)
var _ sql.CollationCoercible = (*NoopTriggerRollback)(nil)

func NewNoopTriggerRollback(child sql.Node) *NoopTriggerRollback {
return &NoopTriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *NoopTriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewNoopTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *NoopTriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *NoopTriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *NoopTriggerRollback) IsReadOnly() bool {
return true
}

func (t *NoopTriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *NoopTriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}
3 changes: 1 addition & 2 deletions sql/rowexec/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ type ExecBuilderFunc func(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter,
// sql.ExecSourceRel are also built into the tree.
type BaseBuilder struct {
// if override is provided, we try to build executor with this first
override sql.NodeExecBuilder
triggerSavePointCounter int // tracks the number of save points that have been created by triggers
override sql.NodeExecBuilder
}

func (b *BaseBuilder) Build(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter, error) {
Expand Down
26 changes: 0 additions & 26 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,32 +256,6 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
return rowIterWithOkResultWithZeroRowsAffected(), nil
}

func (b *BaseBuilder) buildTriggerRollback(ctx *sql.Context, n *plan.TriggerRollback, row sql.Row) (sql.RowIter, error) {
childIter, err := b.buildNodeExec(ctx, n.Child, row)
if err != nil {
return nil, err
}

savePointCounter := b.triggerSavePointCounter + 1
savePointName := fmt.Sprintf("%s%v", TriggerSavePointPrefix, savePointCounter)
ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", savePointName)

ts, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
}

if err := ts.CreateSavepoint(ctx, ctx.GetTransaction(), savePointName); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}
b.triggerSavePointCounter = savePointCounter

return &triggerRollbackIter{
child: childIter,
savePointName: savePointName,
}, nil
}

func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) {
err := b.executeAlterIndex(ctx, n)
if err != nil {
Expand Down
41 changes: 41 additions & 0 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,47 @@ type triggerRollbackIter struct {
savePointName string
}

func containsTrigger(node sql.Node) bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good use case for query flags. When we build a trigger in planbuilder mark a trigger use, and then read that back instead of traversing the node.

// Check if tree contains a TriggerExecutor
hasTrigger := false
transform.Inspect(node, func(n sql.Node) bool {
switch nn := n.(type) {
case *plan.TriggerExecutor:
hasTrigger = true
return false
case *plan.InsertInto:
// Before Triggers on Inserts are inside Source
if _, ok := nn.Source.(*plan.TriggerExecutor); ok {
hasTrigger = true
return false
}
}
return true
})
return hasTrigger
}

func AddTriggerRollbackIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) sql.RowIter {
if !containsTrigger(node) {
return iter
}

transSess, isTransSess := ctx.Session.(sql.TransactionSession)
if !isTransSess {
return iter
}

ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", TriggerSavePointPrefix)
if err := transSess.CreateSavepoint(ctx, ctx.GetTransaction(), TriggerSavePointPrefix); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}

return &triggerRollbackIter{
child: iter,
savePointName: TriggerSavePointPrefix,
}
}

func (t *triggerRollbackIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
childRow, err := t.child.Next(ctx)

Expand Down
4 changes: 0 additions & 4 deletions sql/rowexec/node_builder.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildHaving(ctx, n, row)
case *plan.Signal:
return b.buildSignal(ctx, n, row)
case *plan.TriggerRollback:
return b.buildTriggerRollback(ctx, n, row)
case *plan.ExternalProcedure:
return b.buildExternalProcedure(ctx, n, row)
case *plan.Into:
Expand Down Expand Up @@ -246,8 +244,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildCreateIndex(ctx, n, row)
case *plan.Procedure:
return b.buildProcedure(ctx, n, row)
case *plan.NoopTriggerRollback:
return b.buildNoopTriggerRollback(ctx, n, row)
case *plan.With:
return b.buildWith(ctx, n, row)
case *plan.Project:
Expand Down
5 changes: 0 additions & 5 deletions sql/rowexec/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,6 @@ func (b *BaseBuilder) buildCommit(ctx *sql.Context, n *plan.Commit, row sql.Row)
return sql.RowsToRowIter(), nil
}

func (b *BaseBuilder) buildNoopTriggerRollback(ctx *sql.Context, n *plan.NoopTriggerRollback, row sql.Row) (sql.RowIter, error) {
return b.buildNodeExec(ctx, n.Child, row)

}

func (b *BaseBuilder) buildKill(ctx *sql.Context, n *plan.Kill, row sql.Row) (sql.RowIter, error) {
return &lazyRowIter{
func(ctx *sql.Context) (sql.Row, error) {
Expand Down