From 5e9e2f3405892259089b08909735378166685aa5 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Thu, 5 Jun 2025 16:33:39 -0700 Subject: [PATCH] Minor updates to support UPDATE ... FROM in Doltgres, through the existing UpdateJoin support --- sql/analyzer/apply_foreign_keys.go | 2 + sql/plan/update_join.go | 5 ++ sql/planbuilder/dml.go | 94 +++++++++++++----------------- sql/rowexec/update.go | 8 ++- 4 files changed, 52 insertions(+), 57 deletions(-) diff --git a/sql/analyzer/apply_foreign_keys.go b/sql/analyzer/apply_foreign_keys.go index e958799bcc..166888c8f1 100644 --- a/sql/analyzer/apply_foreign_keys.go +++ b/sql/analyzer/apply_foreign_keys.go @@ -122,6 +122,8 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f if plan.IsEmptyTable(n.Child) { return n, transform.SameTree, nil } + // TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement + // sql.ForeignKeyTable, we do not currently support FK checks for UPDATE JOIN statements. updateDest, err := plan.GetUpdatable(n.Child) if err != nil { return nil, transform.SameTree, err diff --git a/sql/plan/update_join.go b/sql/plan/update_join.go index 814e953a26..d8da167fa8 100644 --- a/sql/plan/update_join.go +++ b/sql/plan/update_join.go @@ -54,6 +54,11 @@ func (u *UpdateJoin) DebugString() string { // GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable. func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable { + // TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table. + // Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code + // expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable + // doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks. + // We should revamp this function so that we can communicate multiple tables being updated. 
return &updatableJoinTable{ updaters: u.Updaters, joinNode: u.Child.(*UpdateSource).Child, diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 60b4ef9090..4752633dc3 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) { return } +// buildUpdate builds an Update node from |u|. If the update joins tables, the returned Update node's +// children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We +// don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require +// analyzer processing that converts the subquery into a join, and then requires the same logic to +// create an UpdateJoin node under the original Update node. func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { // TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is. // The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect. @@ -534,44 +539,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { update.IsProcNested = b.ProcCtx().DbName != "" var checks []*sql.CheckConstraint - if join, ok := outScope.node.(*plan.JoinNode); ok { - // TODO this doesn't work, a lot of the time the top node - // is a filter. This would have to go before we build the - // filter/accessory nodes. But that errors for a lot of queries. 
- source := plan.NewUpdateSource( - join, - ignore, - updateExprs, - ) - updaters, err := rowUpdatersByTable(b.ctx, source, join) + if hasJoinNode(outScope.node) { + tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, update.Child, outScope.node) if err != nil { b.handleErr(err) } - updateJoin := plan.NewUpdateJoin(updaters, source) - update.Child = updateJoin - transform.Inspect(update, func(n sql.Node) bool { - // todo maybe this should be later stage - switch n := n.(type) { - case sql.NameableNode: - if _, ok := updaters[n.Name()]; ok { - rt := getResolvedTable(n) - tableScope := inScope.push() - for _, c := range rt.Schema() { - tableScope.addColumn(scopeColumn{ - db: rt.SqlDatabase.Name(), - table: strings.ToLower(n.Name()), - tableId: tableScope.tables[strings.ToLower(n.Name())], - col: strings.ToLower(c.Name), - typ: c.Type, - nullable: c.Nullable, - }) - } - checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...) - } - default: + + for _, rt := range tablesToUpdate { + tableScope := inScope.push() + for _, c := range rt.Schema() { + tableScope.addColumn(scopeColumn{ + db: rt.SqlDatabase.Name(), + table: strings.ToLower(rt.Name()), + tableId: tableScope.tables[strings.ToLower(rt.Name())], + col: strings.ToLower(c.Name), + typ: c.Type, + nullable: c.Nullable, + }) } - return true - }) + checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...) + } } else { transform.Inspect(update, func(n sql.Node) bool { // todo maybe this should be later stage @@ -594,35 +581,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { return } -// rowUpdatersByTable maps a set of tables to their RowUpdater objects. 
-func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) { - namesOfTableToBeUpdated := getTablesToBeUpdated(node) - resolvedTables := getTablesByName(ij) - - rowUpdatersByTable := make(map[string]sql.RowUpdater) - for tableToBeUpdated, _ := range namesOfTableToBeUpdated { - resolvedTable, ok := resolvedTables[strings.ToLower(tableToBeUpdated)] - if !ok { - return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) +// hasJoinNode returns true if |node| or any child is a JoinNode. +func hasJoinNode(node sql.Node) bool { + updateJoinFound := false + transform.Inspect(node, func(n sql.Node) bool { + if _, ok := n.(*plan.JoinNode); ok { + updateJoinFound = true } + return !updateJoinFound + }) + return updateJoinFound +} - var table = resolvedTable.UnderlyingTable() +func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) { + namesOfTablesToBeUpdated := getTablesToBeUpdated(node) + resolvedTablesMap := getTablesByName(ij) - // If there is no UpdatableTable for a table being updated, error out - updatable, ok := table.(sql.UpdatableTable) - if !ok && updatable == nil { + for tableToBeUpdated, _ := range namesOfTablesToBeUpdated { + resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)] + if !ok { return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) } - keyless := sql.IsKeyless(updatable.Schema()) - if keyless { - return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN") - } - - rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx) + resolvedTables = append(resolvedTables, resolvedTable) } - return rowUpdatersByTable, nil + return resolvedTables, nil } // getTablesByName takes a node and returns all found resolved tables in a map. 
diff --git a/sql/rowexec/update.go b/sql/rowexec/update.go index 2c4cf4eff1..4095465cbf 100644 --- a/sql/rowexec/update.go +++ b/sql/rowexec/update.go @@ -258,8 +258,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { if errors.Is(err, sql.ErrKeyNotFound) { cache.Put(hash, struct{}{}) - // updateJoin counts matched rows from join output - u.accumulator.handleRowMatched() + // updateJoin counts matched rows from join output, unless a RETURNING clause + // is in use, in which case there will not be an accumulator assigned, since we + // don't need to return the count of updated rows, just the RETURNING expressions. + if u.accumulator != nil { + u.accumulator.handleRowMatched() + } continue } else if err != nil {