Skip to content

Commit e6d019a

Browse files
committed
implemented joinUpdater
1 parent 336c07a commit e6d019a

File tree

3 files changed

+77
-32
lines changed

3 files changed

+77
-32
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125-
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126-
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
127125
targets := n.GetUpdateTargets()
128126
foreignKeyHandlers := make([]sql.Node, len(targets))
129127
copy(foreignKeyHandlers, targets)

sql/plan/update.go

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package plan
1616

1717
import (
1818
"fmt"
19-
2019
"gopkg.in/src-d/go-errors.v1"
2120

2221
"github.com/dolthub/go-mysql-server/sql"
@@ -170,17 +169,8 @@ func (u *Update) Expressions() []sql.Expression {
170169
return exprs
171170
}
172171

173-
func (u *Update) updateJoinTargetsResolved() bool {
174-
for _, target := range u.updateJoinTargets {
175-
if target.Resolved() == false {
176-
return false
177-
}
178-
}
179-
return true
180-
}
181-
182172
func (u *Update) Resolved() bool {
183-
return u.Child.Resolved() && u.updateJoinTargetsResolved() &&
173+
return u.Child.Resolved() &&
184174
expression.ExpressionsResolved(u.checks.ToExpressions()...) &&
185175
expression.ExpressionsResolved(u.Returning...)
186176

@@ -264,31 +254,97 @@ func (u *Update) WithJoinSchema(schema sql.Schema) *Update {
264254
return &ret
265255
}
266256

267-
func (u *Update) JoinUpdater() sql.RowUpdater {
268-
updaters := make([]sql.RowUpdater, len(u.updateJoinTargets))
269-
return &joinUpdater{
270-
updaters: updaters,
271-
joinSchema: u.joinSchema,
257+
func (u *Update) GetUpdaterAndSchema(ctx *sql.Context) (sql.RowUpdater, sql.Schema, error) {
258+
if u.IsJoin {
259+
updaterMap := make(map[string]sql.RowUpdater)
260+
for _, target := range u.updateJoinTargets {
261+
targetTable, err := GetUpdatable(target)
262+
if err != nil {
263+
return nil, nil, err
264+
}
265+
updaterMap[targetTable.Name()] = targetTable.Updater(ctx)
266+
}
267+
return &joinUpdater{
268+
updaterMap: updaterMap,
269+
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinSchema),
270+
joinSchema: u.joinSchema,
271+
}, u.joinSchema, nil
272272
}
273+
updatable, err := GetUpdatable(u.Child)
274+
if err != nil {
275+
return nil, nil, err
276+
}
277+
return updatable.Updater(ctx), updatable.Schema(), nil
273278
}
274279

275280
type joinUpdater struct {
276-
updaters []sql.RowUpdater
281+
updaterMap map[string]sql.RowUpdater
282+
schemaMap map[string]sql.Schema
277283
joinSchema sql.Schema
278284
}
279285

280286
var _ sql.RowUpdater = (*joinUpdater)(nil)
281287

282-
func (u *joinUpdater) StatementBegin(ctx *sql.Context) {}
288+
// StatementBegins implements the sql.TableEditor interface
289+
func (u *joinUpdater) StatementBegin(ctx *sql.Context) {
290+
for _, updater := range u.updaterMap {
291+
updater.StatementBegin(ctx)
292+
}
293+
}
294+
295+
// DiscardChanges implements the sql.TableEditor interface
283296
func (u *joinUpdater) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
297+
for _, updater := range u.updaterMap {
298+
err := updater.DiscardChanges(ctx, errorEncountered)
299+
if err != nil {
300+
return err
301+
}
302+
}
284303
return nil
285304
}
305+
306+
// StatementComplete implements the sql.TableEditor interface
286307
func (u *joinUpdater) StatementComplete(ctx *sql.Context) error {
308+
for _, updater := range u.updaterMap {
309+
err := updater.StatementComplete(ctx)
310+
if err != nil {
311+
return err
312+
}
313+
}
287314
return nil
288315
}
289316
func (u *joinUpdater) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
317+
tableToOldRowMap := SplitRowIntoTableRowMap(old, u.joinSchema)
318+
tableToNewRowMap := SplitRowIntoTableRowMap(new, u.joinSchema)
319+
320+
for tableName, updater := range u.updaterMap {
321+
oldRow := tableToOldRowMap[tableName]
322+
newRow := tableToNewRowMap[tableName]
323+
schema := u.schemaMap[tableName]
324+
325+
eq, err := oldRow.Equals(ctx, newRow, schema)
326+
if err != nil {
327+
return err
328+
}
329+
330+
if !eq {
331+
err = updater.Update(ctx, oldRow, newRow)
332+
}
333+
334+
if err != nil {
335+
return err
336+
}
337+
}
338+
290339
return nil
291340
}
341+
292342
func (u *joinUpdater) Close(ctx *sql.Context) error {
343+
for _, updater := range u.updaterMap {
344+
err := updater.Close(ctx)
345+
if err != nil {
346+
return err
347+
}
348+
}
293349
return nil
294350
}

sql/rowexec/dml.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,18 +157,9 @@ func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKe
157157
}
158158

159159
func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row) (sql.RowIter, error) {
160-
var updater sql.RowUpdater
161-
var schema sql.Schema
162-
if n.IsJoin {
163-
updater = n.JoinUpdater()
164-
schema = n.Schema()
165-
} else {
166-
updatable, err := plan.GetUpdatable(n.Child)
167-
if err != nil {
168-
return nil, err
169-
}
170-
updater = updatable.Updater(ctx)
171-
schema = updatable.Schema()
160+
updater, schema, err := n.GetUpdaterAndSchema(ctx)
161+
if err != nil {
162+
return nil, err
172163
}
173164

174165
iter, err := b.buildNodeExec(ctx, n.Child, row)

0 commit comments

Comments
 (0)