Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,22 +202,37 @@ func TestSingleQueryPrepared(t *testing.T) {

// Convenience test for debugging a single query. Unskip and set to the desired query.
func TestSingleScript(t *testing.T) {
t.Skip()
//t.Skip()
var scripts = []queries.ScriptTest{
{
Name: "AS OF propagates to nested CALLs",
SetUpScript: []string{},
Dialect: "mysql",
Name: "UPDATE join – multiple tables, with trigger",
SetUpScript: []string{
"create table customers (id int primary key, name text, tier text)",
"create table orders (id int primary key, customer_id int, status text)",
"create table trigger_log (msg text)",
`CREATE TRIGGER after_orders_update after update on orders for each row
begin
insert into trigger_log (msg) values(
concat('Order ', OLD.id, ' status changed from ', OLD.status, ' to ', NEW.status));
end;`,
`Create trigger after_customers_update after update on customers for each row
begin
insert into trigger_log (msg) values(
concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier));
end;`,
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
"update customers c join orders o on c.id = o.customer_id set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "create procedure create_proc() create table t (i int primary key, j int);",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "call create_proc()",
Query: "SELECT * FROM trigger_log order by msg;",
Expected: []sql.Row{
{types.NewOkResult(0)},
{"Customer 1 tier changed from silver to platinum"},
{"Customer 2 tier changed from gold to platinum"},
{"Order 101 status changed from pending to shipped"},
{"Order 102 status changed from pending to shipped"},
},
},
},
Expand All @@ -232,8 +247,8 @@ func TestSingleScript(t *testing.T) {
panic(err)
}

//engine.EngineAnalyzer().Debug = true
//engine.EngineAnalyzer().Verbose = true
engine.EngineAnalyzer().Debug = true
engine.EngineAnalyzer().Verbose = true

enginetest.TestScriptWithEngine(t, engine, harness, test)
}
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/aliases.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func getTableAliases(n sql.Node, scope *plan.Scope) (TableAliases, error) {
var recScope *plan.Scope
if !scope.IsEmpty() {
recScope = recScope.WithMemos(scope.Memos)
recScope.InUpdateJoin = scope.InUpdateJoin
}

aliasFn = func(node sql.Node) bool {
Expand All @@ -179,6 +180,9 @@ func getTableAliases(n sql.Node, scope *plan.Scope) (TableAliases, error) {
case *plan.RecursiveCte:
case sql.NameableNode:
analysisErr = passAliases.addUnqualified(at.Name(), t)
if scope != nil && scope.InUpdateJoin {
analysisErr = nil
}
case *plan.UnresolvedTable:
panic("Table not resolved")
default:
Expand Down
5 changes: 5 additions & 0 deletions sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,11 @@ func fixExprToScope(e sql.Expression, scopes ...*idxScope) sql.Expression {
// don't have the destination schema, and column references in default values are determined in the build phase)

idx, _ := newScope.getIdxId(e.Id(), e.String())

if e.String() == "old.id" {
print()
}

if idx >= 0 {
return e.WithIndex(idx), transform.NewTree, nil
}
Expand Down
24 changes: 24 additions & 0 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,33 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
plan.NewSubqueryAlias("new", "", updateSrc.Child),
),
)

updateTargets := n.(*plan.Update).Child.(*plan.UpdateJoin).UpdateTargets
if proj, isProj := updateSrc.Child.(*plan.Project); isProj {
oldExprs := make([]sql.Expression, len(proj.Expressions()))
newExprs := make([]sql.Expression, len(proj.Expressions()))
for i, expr := range proj.Expressions() {
if gf, isGf := expr.(*expression.GetField); isGf {
if tbl, ok := updateTargets[gf.Table()]; ok {
if tbl.(*plan.ResolvedTable).Name() == trigger.Table.(*plan.ResolvedTable).Name() {
oldExprs[i] = gf.WithTable("old")
newExprs[i] = gf.WithTable("new")
continue
}
}
}
oldExprs[i] = expr
newExprs[i] = expr
}
scopeNode.Child = plan.NewCrossJoin(
plan.NewProject(oldExprs, proj.Child),
plan.NewProject(newExprs, proj.Child),
)
}
}
// Triggers are wrapped in prepend nodes, which means that the parent scope is included
s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
s.InUpdateJoin = true
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
case sqlparser.DeleteStr:
scopeNode := plan.NewProject(
Expand Down
2 changes: 2 additions & 0 deletions sql/plan/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Scope struct {
JoinTrees []string

inInsertSource bool
InUpdateJoin bool
}

func (s *Scope) IsEmpty() bool {
Expand Down Expand Up @@ -79,6 +80,7 @@ func (s *Scope) NewScope(node sql.Node) *Scope {
recursionDepth: s.recursionDepth + 1,
Procedures: s.Procedures,
joinSiblings: s.joinSiblings,
InUpdateJoin: s.InUpdateJoin,
}
}

Expand Down
Loading