Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sql/analyzer/apply_foreign_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
if plan.IsEmptyTable(n.Child) {
return n, transform.SameTree, nil
}
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
updateDest, err := plan.GetUpdatable(n.Child)
if err != nil {
return nil, transform.SameTree, err
Expand Down
5 changes: 5 additions & 0 deletions sql/plan/update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func (u *UpdateJoin) DebugString() string {

// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
// We should revamp this function so that we can communicate multiple tables being updated.
return &updatableJoinTable{
updaters: u.Updaters,
joinNode: u.Child.(*UpdateSource).Child,
Expand Down
94 changes: 39 additions & 55 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
return
}

// buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
// children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
// don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
// analyzer processing that converts the subquery into a join, and then requires the same logic to
// create an UpdateJoin node under the original Update node.
func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
// TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
// The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
Expand Down Expand Up @@ -534,44 +539,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
update.IsProcNested = b.ProcCtx().DbName != ""

var checks []*sql.CheckConstraint
if join, ok := outScope.node.(*plan.JoinNode); ok {
// TODO this doesn't work, a lot of the time the top node
// is a filter. This would have to go before we build the
// filter/accessory nodes. But that errors for a lot of queries.
source := plan.NewUpdateSource(
join,
ignore,
updateExprs,
)
updaters, err := rowUpdatersByTable(b.ctx, source, join)
if hasJoinNode(outScope.node) {
tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, update.Child, outScope.node)
if err != nil {
b.handleErr(err)
}
updateJoin := plan.NewUpdateJoin(updaters, source)
update.Child = updateJoin
transform.Inspect(update, func(n sql.Node) bool {
// todo maybe this should be later stage
switch n := n.(type) {
case sql.NameableNode:
if _, ok := updaters[n.Name()]; ok {
rt := getResolvedTable(n)
tableScope := inScope.push()
for _, c := range rt.Schema() {
tableScope.addColumn(scopeColumn{
db: rt.SqlDatabase.Name(),
table: strings.ToLower(n.Name()),
tableId: tableScope.tables[strings.ToLower(n.Name())],
col: strings.ToLower(c.Name),
typ: c.Type,
nullable: c.Nullable,
})
}
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
}
default:

for _, rt := range tablesToUpdate {
tableScope := inScope.push()
for _, c := range rt.Schema() {
tableScope.addColumn(scopeColumn{
db: rt.SqlDatabase.Name(),
table: strings.ToLower(rt.Name()),
tableId: tableScope.tables[strings.ToLower(rt.Name())],
col: strings.ToLower(c.Name),
typ: c.Type,
nullable: c.Nullable,
})
}
return true
})
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
}
} else {
transform.Inspect(update, func(n sql.Node) bool {
// todo maybe this should be later stage
Expand All @@ -594,35 +581,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
return
}

// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
resolvedTables := getTablesByName(ij)

rowUpdatersByTable := make(map[string]sql.RowUpdater)
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
resolvedTable, ok := resolvedTables[strings.ToLower(tableToBeUpdated)]
if !ok {
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
// hasJoinNode returns true if |node| or any child is a JoinNode.
func hasJoinNode(node sql.Node) bool {
updateJoinFound := false
transform.Inspect(node, func(n sql.Node) bool {
if _, ok := n.(*plan.JoinNode); ok {
updateJoinFound = true
}
return !updateJoinFound
})
return updateJoinFound
}

var table = resolvedTable.UnderlyingTable()
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
resolvedTablesMap := getTablesByName(ij)

// If there is no UpdatableTable for a table being updated, error out
updatable, ok := table.(sql.UpdatableTable)
if !ok && updatable == nil {
for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)]
if !ok {
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
}

keyless := sql.IsKeyless(updatable.Schema())
if keyless {
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
}

rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
resolvedTables = append(resolvedTables, resolvedTable)
}

return rowUpdatersByTable, nil
return resolvedTables, nil
}

// getTablesByName takes a node and returns all found resolved tables in a map.
Expand Down
8 changes: 6 additions & 2 deletions sql/rowexec/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
if errors.Is(err, sql.ErrKeyNotFound) {
cache.Put(hash, struct{}{})

// updateJoin counts matched rows from join output
u.accumulator.handleRowMatched()
// updateJoin counts matched rows from join output, unless a RETURNING clause
// is in use, in which case there will not be an accumulator assigned, since we
// don't need to return the count of updated rows, just the RETURNING expressions.
if u.accumulator != nil {
u.accumulator.handleRowMatched()
}

continue
} else if err != nil {
Expand Down