Skip to content

Commit 5e9e2f3

Browse files
committed
Minor updates to support UPDATE ... FROM in Doltgres, through the existing UpdateJoin support
1 parent b7b74d4 commit 5e9e2f3

File tree

4 files changed

+52
-57
lines changed

4 files changed

+52
-57
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125+
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126+
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
125127
updateDest, err := plan.GetUpdatable(n.Child)
126128
if err != nil {
127129
return nil, transform.SameTree, err

sql/plan/update_join.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ func (u *UpdateJoin) DebugString() string {
5454

5555
// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
5656
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
57+
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
58+
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
59+
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
60+
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
61+
// We should revamp this function so that we can communicate multiple tables being updated.
5762
return &updatableJoinTable{
5863
updaters: u.Updaters,
5964
joinNode: u.Child.(*UpdateSource).Child,

sql/planbuilder/dml.go

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
492492
return
493493
}
494494

495+
// buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
496+
// children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
497+
// don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
498+
// analyzer processing that converts the subquery into a join, and then requires the same logic to
499+
// create an UpdateJoin node under the original Update node.
495500
func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
496501
// TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
497502
// The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -534,44 +539,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534539
update.IsProcNested = b.ProcCtx().DbName != ""
535540

536541
var checks []*sql.CheckConstraint
537-
if join, ok := outScope.node.(*plan.JoinNode); ok {
538-
// TODO this doesn't work, a lot of the time the top node
539-
// is a filter. This would have to go before we build the
540-
// filter/accessory nodes. But that errors for a lot of queries.
541-
source := plan.NewUpdateSource(
542-
join,
543-
ignore,
544-
updateExprs,
545-
)
546-
updaters, err := rowUpdatersByTable(b.ctx, source, join)
542+
if hasJoinNode(outScope.node) {
543+
tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, update.Child, outScope.node)
547544
if err != nil {
548545
b.handleErr(err)
549546
}
550-
updateJoin := plan.NewUpdateJoin(updaters, source)
551-
update.Child = updateJoin
552-
transform.Inspect(update, func(n sql.Node) bool {
553-
// todo maybe this should be later stage
554-
switch n := n.(type) {
555-
case sql.NameableNode:
556-
if _, ok := updaters[n.Name()]; ok {
557-
rt := getResolvedTable(n)
558-
tableScope := inScope.push()
559-
for _, c := range rt.Schema() {
560-
tableScope.addColumn(scopeColumn{
561-
db: rt.SqlDatabase.Name(),
562-
table: strings.ToLower(n.Name()),
563-
tableId: tableScope.tables[strings.ToLower(n.Name())],
564-
col: strings.ToLower(c.Name),
565-
typ: c.Type,
566-
nullable: c.Nullable,
567-
})
568-
}
569-
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
570-
}
571-
default:
547+
548+
for _, rt := range tablesToUpdate {
549+
tableScope := inScope.push()
550+
for _, c := range rt.Schema() {
551+
tableScope.addColumn(scopeColumn{
552+
db: rt.SqlDatabase.Name(),
553+
table: strings.ToLower(rt.Name()),
554+
tableId: tableScope.tables[strings.ToLower(rt.Name())],
555+
col: strings.ToLower(c.Name),
556+
typ: c.Type,
557+
nullable: c.Nullable,
558+
})
572559
}
573-
return true
574-
})
560+
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
561+
}
575562
} else {
576563
transform.Inspect(update, func(n sql.Node) bool {
577564
// todo maybe this should be later stage
@@ -594,35 +581,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
594581
return
595582
}
596583

597-
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598-
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
599-
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
600-
resolvedTables := getTablesByName(ij)
601-
602-
rowUpdatersByTable := make(map[string]sql.RowUpdater)
603-
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
604-
resolvedTable, ok := resolvedTables[strings.ToLower(tableToBeUpdated)]
605-
if !ok {
606-
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
584+
// hasJoinNode returns true if |node| or any child is a JoinNode.
585+
func hasJoinNode(node sql.Node) bool {
586+
updateJoinFound := false
587+
transform.Inspect(node, func(n sql.Node) bool {
588+
if _, ok := n.(*plan.JoinNode); ok {
589+
updateJoinFound = true
607590
}
591+
return !updateJoinFound
592+
})
593+
return updateJoinFound
594+
}
608595

609-
var table = resolvedTable.UnderlyingTable()
596+
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
597+
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
598+
resolvedTablesMap := getTablesByName(ij)
610599

611-
// If there is no UpdatableTable for a table being updated, error out
612-
updatable, ok := table.(sql.UpdatableTable)
613-
if !ok && updatable == nil {
600+
for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
601+
resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)]
602+
if !ok {
614603
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
615604
}
616605

617-
keyless := sql.IsKeyless(updatable.Schema())
618-
if keyless {
619-
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
620-
}
621-
622-
rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
606+
resolvedTables = append(resolvedTables, resolvedTable)
623607
}
624608

625-
return rowUpdatersByTable, nil
609+
return resolvedTables, nil
626610
}
627611

628612
// getTablesByName takes a node and returns all found resolved tables in a map.

sql/rowexec/update.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
258258
if errors.Is(err, sql.ErrKeyNotFound) {
259259
cache.Put(hash, struct{}{})
260260

261-
// updateJoin counts matched rows from join output
262-
u.accumulator.handleRowMatched()
261+
// updateJoin counts matched rows from join output, unless a RETURNING clause
262+
// is in use, in which case there will not be an accumulator assigned, since we
263+
// don't need to return the count of updated rows, just the RETURNING expressions.
264+
if u.accumulator != nil {
265+
u.accumulator.handleRowMatched()
266+
}
263267

264268
continue
265269
} else if err != nil {

0 commit comments

Comments
 (0)