Skip to content

Commit 9dfae69

Browse files
committed
Minor updates to support UPDATE ... FROM in Doltgres, through the existing UpdateJoin support
1 parent b7b74d4 commit 9dfae69

File tree

5 files changed

+92
-11
lines changed

5 files changed

+92
-11
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125+
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126+
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
125127
updateDest, err := plan.GetUpdatable(n.Child)
126128
if err != nil {
127129
return nil, transform.SameTree, err

sql/analyzer/assign_update_join.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,18 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
3939
return nil, transform.SameTree, err
4040
}
4141

42+
tablesToUpdate, err := getResolvedTablesToUpdate(ctx, us, jn)
43+
if err != nil {
44+
return n, transform.SameTree, err
45+
}
46+
47+
tableNames := make([]string, len(tablesToUpdate))
48+
for i, tableToUpdate := range tablesToUpdate {
49+
tableNames[i] = tableToUpdate.Name()
50+
}
51+
4252
uj := plan.NewUpdateJoin(updaters, us)
53+
uj.TargetTables = tableNames
4354
ret, err := n.WithChildren(uj)
4455
if err != nil {
4556
return nil, transform.SameTree, err
@@ -99,3 +110,19 @@ func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
99110

100111
return ret
101112
}
113+
114+
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
115+
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
116+
resolvedTablesMap := getTablesByName(ij)
117+
118+
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
119+
resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)]
120+
if !ok {
121+
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
122+
}
123+
124+
resolvedTables = append(resolvedTables, resolvedTable)
125+
}
126+
127+
return resolvedTables, nil
128+
}

sql/plan/update_join.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import (
2121
)
2222

2323
type UpdateJoin struct {
24-
Updaters map[string]sql.RowUpdater
24+
Updaters map[string]sql.RowUpdater
25+
TargetTables []string
2526
UnaryNode
2627
}
2728

@@ -54,7 +55,14 @@ func (u *UpdateJoin) DebugString() string {
5455

5556
// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
5657
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
58+
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
59+
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
60+
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
61+
// doesn't implement ForeignKeyTable, we UpdateJoin statements don't properly enforce foreign key
62+
// checks. We should revamp this function so that we can support multiple tables being updated,
63+
// but right now we just return the name of the
5764
return &updatableJoinTable{
65+
name: u.TargetTables[0],
5866
updaters: u.Updaters,
5967
joinNode: u.Child.(*UpdateSource).Child,
6068
}
@@ -66,7 +74,9 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
6674
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
6775
}
6876

69-
return NewUpdateJoin(u.Updaters, children[0]), nil
77+
newUpdateJoin := NewUpdateJoin(u.Updaters, children[0])
78+
newUpdateJoin.TargetTables = u.TargetTables
79+
return newUpdateJoin, nil
7080
}
7181

7282
func (u *UpdateJoin) IsReadOnly() bool {
@@ -80,6 +90,7 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
8090

8191
// updatableJoinTable manages the update of multiple tables.
8292
type updatableJoinTable struct {
93+
name string
8394
updaters map[string]sql.RowUpdater
8495
joinNode sql.Node
8596
}
@@ -98,7 +109,7 @@ func (u *updatableJoinTable) PartitionRows(context *sql.Context, partition sql.P
98109

99110
// Name implements the sql.UpdatableTable interface.
100111
func (u *updatableJoinTable) Name() string {
101-
panic("this method should not be called")
112+
return u.name
102113
}
103114

104115
// String implements the sql.UpdatableTable interface.

sql/planbuilder/dml.go

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -534,20 +534,29 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534534
update.IsProcNested = b.ProcCtx().DbName != ""
535535

536536
var checks []*sql.CheckConstraint
537-
if join, ok := outScope.node.(*plan.JoinNode); ok {
538-
// TODO this doesn't work, a lot of the time the top node
539-
// is a filter. This would have to go before we build the
540-
// filter/accessory nodes. But that errors for a lot of queries.
537+
if hasJoinNode(outScope.node) {
541538
source := plan.NewUpdateSource(
542-
join,
539+
outScope.node,
543540
ignore,
544541
updateExprs,
545542
)
546-
updaters, err := rowUpdatersByTable(b.ctx, source, join)
543+
updaters, err := rowUpdatersByTable(b.ctx, source, outScope.node)
547544
if err != nil {
548545
b.handleErr(err)
549546
}
547+
548+
tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, source, outScope.node)
549+
if err != nil {
550+
b.handleErr(err)
551+
}
552+
553+
tableNames := make([]string, len(tablesToUpdate))
554+
for i, tableToUpdate := range tablesToUpdate {
555+
tableNames[i] = tableToUpdate.Name()
556+
}
557+
550558
updateJoin := plan.NewUpdateJoin(updaters, source)
559+
updateJoin.TargetTables = tableNames
551560
update.Child = updateJoin
552561
transform.Inspect(update, func(n sql.Node) bool {
553562
// todo maybe this should be later stage
@@ -594,6 +603,34 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
594603
return
595604
}
596605

606+
// hasJoinNode returns true if |node| or any child is a JoinNode.
607+
func hasJoinNode(node sql.Node) bool {
608+
updateJoinFound := false
609+
transform.Inspect(node, func(n sql.Node) bool {
610+
if _, ok := n.(*plan.JoinNode); ok {
611+
updateJoinFound = true
612+
}
613+
return !updateJoinFound
614+
})
615+
return updateJoinFound
616+
}
617+
618+
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
619+
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
620+
resolvedTablesMap := getTablesByName(ij)
621+
622+
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
623+
resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)]
624+
if !ok {
625+
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
626+
}
627+
628+
resolvedTables = append(resolvedTables, resolvedTable)
629+
}
630+
631+
return resolvedTables, nil
632+
}
633+
597634
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598635
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
599636
namesOfTableToBeUpdated := getTablesToBeUpdated(node)

sql/rowexec/update.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
258258
if errors.Is(err, sql.ErrKeyNotFound) {
259259
cache.Put(hash, struct{}{})
260260

261-
// updateJoin counts matched rows from join output
262-
u.accumulator.handleRowMatched()
261+
// updateJoin counts matched rows from join output, unless a RETURNING clause
262+
// is in use, in which case there will not be an accumulator assigned, since we
263+
// don't need to return the count of updated rows, just the RETURNING expressions.
264+
if u.accumulator != nil {
265+
u.accumulator.handleRowMatched()
266+
}
263267

264268
continue
265269
} else if err != nil {

0 commit comments

Comments
 (0)