Skip to content

Commit 9fc9b43

Browse files
committed
Add support for UPDATE ... FROM
1 parent 9a22145 commit 9fc9b43

File tree

4 files changed

+311
-114
lines changed

4 files changed

+311
-114
lines changed

server/analyzer/assign_update_casts.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
package analyzer
1616

1717
import (
18-
"github.com/cockroachdb/errors"
18+
"fmt"
1919

20+
"github.com/cockroachdb/errors"
2021
"github.com/dolthub/go-mysql-server/sql"
2122
"github.com/dolthub/go-mysql-server/sql/analyzer"
2223
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -49,6 +50,24 @@ func AssignUpdateCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
4950
if !ok {
5051
return nil, transform.NewTree, errors.Errorf("UPDATE: assumption that Foreign Key child is always UpdateSource is incorrect: %T", child.OriginalNode)
5152
}
53+
newUpdateSource, err := assignUpdateCastsHandleSource(updateSource)
54+
if err != nil {
55+
return nil, transform.NewTree, err
56+
}
57+
newHandler, err := child.WithChildren(newUpdateSource)
58+
if err != nil {
59+
return nil, transform.NewTree, err
60+
}
61+
newUpdate, err = update.WithChildren(newHandler)
62+
if err != nil {
63+
return nil, transform.NewTree, err
64+
}
65+
case *plan.UpdateJoin:
66+
updateSource, ok := child.Child.(*plan.UpdateSource)
67+
if !ok {
68+
return nil, transform.NewTree, fmt.Errorf("UPDATE: unknown source type: %T", child.Child)
69+
}
70+
5271
newUpdateSource, err := assignUpdateCastsHandleSource(updateSource)
5372
if err != nil {
5473
return nil, transform.NewTree, err

server/ast/update.go

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
package ast
1616

1717
import (
18-
"github.com/cockroachdb/errors"
19-
2018
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2119

2220
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
@@ -39,9 +37,6 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
3937
}
4038
}
4139

42-
if len(node.From) > 0 {
43-
return nil, errors.Errorf("FROM is not yet supported")
44-
}
4540
with, err := nodeWith(ctx, node.With)
4641
if err != nil {
4742
return nil, err
@@ -50,6 +45,27 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
5045
if err != nil {
5146
return nil, err
5247
}
48+
49+
tableExprs := vitess.TableExprs{table}
50+
if len(node.From) > 0 {
51+
vitessTableExprs := make(vitess.TableExprs, len(node.From))
52+
for i, tableExpr := range node.From {
53+
vitessTableExpr, err := nodeTableExpr(ctx, tableExpr)
54+
if err != nil {
55+
return nil, err
56+
}
57+
vitessTableExprs[i] = vitessTableExpr
58+
}
59+
60+
tableExprs = []vitess.TableExpr{
61+
&vitess.JoinTableExpr{
62+
Join: vitess.JoinStr,
63+
LeftExpr: buildJoinTableExpressionTree(ctx, vitessTableExprs),
64+
RightExpr: table,
65+
},
66+
}
67+
}
68+
5369
exprs, err := nodeUpdateExprs(ctx, node.Exprs)
5470
if err != nil {
5571
return nil, err
@@ -67,7 +83,7 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
6783
return nil, err
6884
}
6985
return &vitess.Update{
70-
TableExprs: vitess.TableExprs{table},
86+
TableExprs: tableExprs,
7187
With: with,
7288
Exprs: exprs,
7389
Where: where,
@@ -76,3 +92,23 @@ func nodeUpdate(ctx *Context, node *tree.Update) (update *vitess.Update, err err
7692
Returning: returningExprs,
7793
}, nil
7894
}
95+
96+
// buildJoinTableExpressionTree returns an expression tree of JoinTableExprs with |tableExprs| as the
97+
// leaf nodes. If |tableExprs| is empty or nil, then nil is returned.
98+
func buildJoinTableExpressionTree(ctx *Context, tableExprs vitess.TableExprs) vitess.TableExpr {
99+
switch len(tableExprs) {
100+
case 0:
101+
return nil
102+
case 1:
103+
return tableExprs[0]
104+
case 2:
105+
return &vitess.JoinTableExpr{
106+
Join: vitess.JoinStr,
107+
LeftExpr: tableExprs[0],
108+
RightExpr: tableExprs[1],
109+
}
110+
default:
111+
subtree := buildJoinTableExpressionTree(ctx, tableExprs[0:2])
112+
return buildJoinTableExpressionTree(ctx, append(tableExprs[2:], subtree))
113+
}
114+
}

0 commit comments

Comments
 (0)