Skip to content

Commit 473f514

Browse files
author
James Cor
committed
triggers
1 parent 9814e26 commit 473f514

File tree

4 files changed

+11
-24
lines changed

4 files changed

+11
-24
lines changed

engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {
866866

867867
// finalizeIters applies the final transformations on sql.RowIter before execution.
868868
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
869-
iter = rowexec.AddTriggerRollbackIter(ctx, analyzed, iter)
869+
iter = rowexec.AddTriggerRollbackIter(ctx, qFlags, iter)
870870
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
871871
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
872872
iter = rowexec.AddExpressionCloser(analyzed, iter)

sql/analyzer/triggers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
346346

347347
switch n := c.Node.(type) {
348348
case *plan.InsertInto:
349+
qFlags.Set(sql.QFlagTrigger)
349350
if trigger.TriggerTime == sqlparser.BeforeStr {
350351
triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
351352
Name: trigger.TriggerName,
@@ -359,6 +360,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
359360
}), transform.NewTree, nil
360361
}
361362
case *plan.Update:
363+
qFlags.Set(sql.QFlagTrigger)
362364
if trigger.TriggerTime == sqlparser.BeforeStr {
363365
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
364366
Name: trigger.TriggerName,
@@ -387,6 +389,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
387389
"does not support triggers; retry with single table deletes")
388390
}
389391

392+
qFlags.Set(sql.QFlagTrigger)
390393
if trigger.TriggerTime == sqlparser.BeforeStr {
391394
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
392395
Name: trigger.TriggerName,

sql/query_flags.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ const (
4848
QFlagDeferProjections
4949
// QFlagUndeferrableExprs indicates that the query has expressions that cannot be deferred
5050
QFlagUndeferrableExprs
51+
QFlagTrigger
5152
)
5253

5354
type QueryFlags struct {
@@ -69,6 +70,9 @@ func (qp *QueryFlags) Unset(flag int) {
6970
}
7071

7172
func (qp *QueryFlags) IsSet(flag int) bool {
73+
if qp == nil {
74+
return false
75+
}
7276
return qp.Flags.Contains(flag)
7377
}
7478

sql/rowexec/dml_iters.go

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Dolthub, Inc.
1+
// Copyright 2023-2024 Dolthub, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -33,28 +33,8 @@ type triggerRollbackIter struct {
3333
savePointName string
3434
}
3535

36-
func containsTrigger(node sql.Node) bool {
37-
// Check if tree contains a TriggerExecutor
38-
hasTrigger := false
39-
transform.Inspect(node, func(n sql.Node) bool {
40-
switch nn := n.(type) {
41-
case *plan.TriggerExecutor:
42-
hasTrigger = true
43-
return false
44-
case *plan.InsertInto:
45-
// Before Triggers on Inserts are inside Source
46-
if _, ok := nn.Source.(*plan.TriggerExecutor); ok {
47-
hasTrigger = true
48-
return false
49-
}
50-
}
51-
return true
52-
})
53-
return hasTrigger
54-
}
55-
56-
func AddTriggerRollbackIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) sql.RowIter {
57-
if !containsTrigger(node) {
36+
func AddTriggerRollbackIter(ctx *sql.Context, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
37+
if !qFlags.IsSet(sql.QFlagTrigger) {
5838
return iter
5939
}
6040

0 commit comments

Comments
 (0)