Skip to content

Commit f6dd664

Browse files
authored
Merge pull request #1536 from dolthub/fulghum/update_from
Add support for `UPDATE ... FROM`
2 parents 44f3944 + 49fd9f0 commit f6dd664

File tree

5 files changed

+356
-117
lines changed

5 files changed

+356
-117
lines changed

server/analyzer/assign_triggers.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ func getTriggerInformation(ctx *sql.Context, node sql.Node) (sch sql.Schema, bef
9595
return nil, nil, nil, err
9696
}
9797
case *plan.Update:
98+
// TODO: If there is a JoinNode in here, then don't bother calling GetUpdatable, because
99+
// it doesn't currently return a type that can be used to query trigger information.
100+
// We need to rework the plan.GetUpdatable() API to support returning multiple
101+
// update targets and to return types that are compatible with the interfaces
102+
// Doltgres needs in order to populate trigger information.
103+
if hasJoinNode(node) {
104+
return nil, nil, nil, nil
105+
}
106+
98107
tbl, err = plan.GetUpdatable(node.Child)
99108
if err != nil {
100109
return nil, nil, nil, err
@@ -168,6 +177,18 @@ func getTriggerInformation(ctx *sql.Context, node sql.Node) (sch sql.Schema, bef
168177
return tbl.Schema(), beforeTrigs, afterTrigs, nil
169178
}
170179

180+
// hasJoinNode returns true if |node| or any child is a JoinNode.
181+
func hasJoinNode(node sql.Node) bool {
182+
updateJoinFound := false
183+
transform.Inspect(node, func(n sql.Node) bool {
184+
if _, ok := n.(*plan.JoinNode); ok {
185+
updateJoinFound = true
186+
}
187+
return !updateJoinFound
188+
})
189+
return updateJoinFound
190+
}
191+
171192
// getTriggerSource returns the trigger's source node.
172193
func getTriggerSource(node sql.Node) sql.Node {
173194
switch node := node.(type) {

server/analyzer/assign_update_casts.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
package analyzer
1616

1717
import (
18-
"github.com/cockroachdb/errors"
18+
"fmt"
1919

20+
"github.com/cockroachdb/errors"
2021
"github.com/dolthub/go-mysql-server/sql"
2122
"github.com/dolthub/go-mysql-server/sql/analyzer"
2223
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -49,6 +50,24 @@ func AssignUpdateCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
4950
if !ok {
5051
return nil, transform.NewTree, errors.Errorf("UPDATE: assumption that Foreign Key child is always UpdateSource is incorrect: %T", child.OriginalNode)
5152
}
53+
newUpdateSource, err := assignUpdateCastsHandleSource(updateSource)
54+
if err != nil {
55+
return nil, transform.NewTree, err
56+
}
57+
newHandler, err := child.WithChildren(newUpdateSource)
58+
if err != nil {
59+
return nil, transform.NewTree, err
60+
}
61+
newUpdate, err = update.WithChildren(newHandler)
62+
if err != nil {
63+
return nil, transform.NewTree, err
64+
}
65+
case *plan.UpdateJoin:
66+
updateSource, ok := child.Child.(*plan.UpdateSource)
67+
if !ok {
68+
return nil, transform.NewTree, fmt.Errorf("UPDATE: unknown source type: %T", child.Child)
69+
}
70+
5271
newUpdateSource, err := assignUpdateCastsHandleSource(updateSource)
5372
if err != nil {
5473
return nil, transform.NewTree, err

server/ast/update.go

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
package ast
1616

1717
import (
18-
"github.com/cockroachdb/errors"
19-
2018
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2119

2220
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
@@ -39,9 +37,6 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
3937
}
4038
}
4139

42-
if len(node.From) > 0 {
43-
return nil, errors.Errorf("FROM is not yet supported")
44-
}
4540
with, err := nodeWith(ctx, node.With)
4641
if err != nil {
4742
return nil, err
@@ -50,6 +45,27 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
5045
if err != nil {
5146
return nil, err
5247
}
48+
49+
tableExprs := vitess.TableExprs{table}
50+
if len(node.From) > 0 {
51+
vitessTableExprs := make(vitess.TableExprs, len(node.From))
52+
for i, tableExpr := range node.From {
53+
vitessTableExpr, err := nodeTableExpr(ctx, tableExpr)
54+
if err != nil {
55+
return nil, err
56+
}
57+
vitessTableExprs[i] = vitessTableExpr
58+
}
59+
60+
tableExprs = []vitess.TableExpr{
61+
&vitess.JoinTableExpr{
62+
Join: vitess.JoinStr,
63+
LeftExpr: buildJoinTableExpressionTree(ctx, vitessTableExprs),
64+
RightExpr: table,
65+
},
66+
}
67+
}
68+
5369
exprs, err := nodeUpdateExprs(ctx, node.Exprs)
5470
if err != nil {
5571
return nil, err
@@ -67,7 +83,7 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
6783
return nil, err
6884
}
6985
return &vitess.Update{
70-
TableExprs: vitess.TableExprs{table},
86+
TableExprs: tableExprs,
7187
With: with,
7288
Exprs: exprs,
7389
Where: where,
@@ -76,3 +92,23 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
7692
Returning: returningExprs,
7793
}, nil
7894
}
95+
96+
// buildJoinTableExpressionTree returns an expression tree of JoinTableExprs with |tableExprs| as the
97+
// leaf nodes. If |tableExprs| is empty or nil, then nil is returned.
98+
func buildJoinTableExpressionTree(ctx *Context, tableExprs vitess.TableExprs) vitess.TableExpr {
99+
switch len(tableExprs) {
100+
case 0:
101+
return nil
102+
case 1:
103+
return tableExprs[0]
104+
case 2:
105+
return &vitess.JoinTableExpr{
106+
Join: vitess.JoinStr,
107+
LeftExpr: tableExprs[0],
108+
RightExpr: tableExprs[1],
109+
}
110+
default:
111+
subtree := buildJoinTableExpressionTree(ctx, tableExprs[0:2])
112+
return buildJoinTableExpressionTree(ctx, append(tableExprs[2:], subtree))
113+
}
114+
}

0 commit comments

Comments
 (0)