Skip to content

Commit c30dd34

Browse files
committed
modified UpdateJoin to contain target node
1 parent 4cf9d24 commit c30dd34

File tree

4 files changed

+59
-40
lines changed

4 files changed

+59
-40
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,27 +128,33 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
128128
if err != nil {
129129
return nil, transform.SameTree, err
130130
}
131-
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
132-
// If foreign keys aren't supported then we return
133-
if !ok {
134-
return n, transform.SameTree, nil
135-
}
131+
switch updateDest.(type) {
132+
case *plan.UpdatableJoinTable:
136133

137-
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
138-
if err != nil {
139-
return nil, transform.SameTree, err
140-
}
141-
if fkEditor == nil {
142134
return n, transform.SameTree, nil
135+
default:
136+
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
137+
// If foreign keys aren't supported then we return
138+
if !ok {
139+
return n, transform.SameTree, nil
140+
}
141+
142+
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
143+
if err != nil {
144+
return nil, transform.SameTree, err
145+
}
146+
if fkEditor == nil {
147+
return n, transform.SameTree, nil
148+
}
149+
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
150+
Table: fkTbl,
151+
Sch: updateDest.Schema(),
152+
OriginalNode: n.Child,
153+
Editor: fkEditor,
154+
AllUpdaters: fkChain.GetUpdaters(),
155+
})
156+
return nn, transform.NewTree, err
143157
}
144-
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
145-
Table: fkTbl,
146-
Sch: updateDest.Schema(),
147-
OriginalNode: n.Child,
148-
Editor: fkEditor,
149-
AllUpdaters: fkChain.GetUpdaters(),
150-
})
151-
return nn, transform.NewTree, err
152158
case *plan.DeleteFrom:
153159
if plan.IsEmptyTable(n.Child) {
154160
return n, transform.SameTree, nil

sql/analyzer/assign_update_join.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
3434
return n, transform.SameTree, nil
3535
}
3636

37-
updatables, err := updatablesByTable(us, jn)
37+
updateTargets, err := updateTargetsByTable(us, jn)
3838
if err != nil {
3939
return nil, transform.SameTree, err
4040
}
4141

42-
uj := plan.NewUpdateJoin(updatables, us)
42+
uj := plan.NewUpdateJoin(updateTargets, us)
4343
ret, err := n.WithChildren(uj)
4444
if err != nil {
4545
return nil, transform.SameTree, err
@@ -52,11 +52,11 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
5252
}
5353

5454
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
55-
func updatablesByTable(node sql.Node, ij sql.Node) (map[string]sql.UpdatableTable, error) {
55+
func updateTargetsByTable(node sql.Node, ij sql.Node) (map[string]sql.Node, error) {
5656
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
5757
resolvedTables := getTablesByName(ij)
5858

59-
updatables := make(map[string]sql.UpdatableTable)
59+
updateTargets := make(map[string]sql.Node)
6060
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
6161
resolvedTable, ok := resolvedTables[tableToBeUpdated]
6262
if !ok {
@@ -76,10 +76,10 @@ func updatablesByTable(node sql.Node, ij sql.Node) (map[string]sql.UpdatableTabl
7676
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
7777
}
7878

79-
updatables[tableToBeUpdated] = updatable
79+
updateTargets[tableToBeUpdated] = resolvedTable
8080
}
8181

82-
return updatables, nil
82+
return updateTargets, nil
8383
}
8484

8585
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.

sql/plan/update_join.go

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ import (
2121
)
2222

2323
type UpdateJoin struct {
24-
Updatables map[string]sql.UpdatableTable
24+
UpdateTargets map[string]sql.Node
2525
UnaryNode
2626
}
2727

2828
// NewUpdateJoin returns an *UpdateJoin node.
29-
func NewUpdateJoin(updatablesMap map[string]sql.UpdatableTable, child sql.Node) *UpdateJoin {
29+
func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin {
3030
return &UpdateJoin{
31-
Updatables: updatablesMap,
32-
UnaryNode: UnaryNode{Child: child},
31+
UpdateTargets: updateTargets,
32+
UnaryNode: UnaryNode{Child: child},
3333
}
3434
}
3535

@@ -60,8 +60,8 @@ func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
6060
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
6161
// We should revamp this function so that we can communicate multiple tables being updated.
6262
return &UpdatableJoinTable{
63-
updatables: u.Updatables,
64-
joinNode: u.Child.(*UpdateSource).Child,
63+
UpdateTargets: u.UpdateTargets,
64+
joinNode: u.Child.(*UpdateSource).Child,
6565
}
6666
}
6767

@@ -71,7 +71,11 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
7171
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
7272
}
7373

74-
return NewUpdateJoin(u.Updatables, children[0]), nil
74+
return NewUpdateJoin(u.UpdateTargets, children[0]), nil
75+
}
76+
77+
func (u *UpdateJoin) WithUpdateTargets(updateTargets map[string]sql.Node) *UpdateJoin {
78+
return NewUpdateJoin(updateTargets, u.Child)
7579
}
7680

7781
func (u *UpdateJoin) IsReadOnly() bool {
@@ -83,22 +87,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
8387
return sql.GetCoercibility(ctx, u.Child)
8488
}
8589

86-
func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) map[string]sql.RowUpdater {
87-
return getUpdaters(u.Updatables, ctx)
90+
func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) {
91+
return getUpdaters(u.UpdateTargets, ctx)
8892
}
8993

90-
func getUpdaters(updatables map[string]sql.UpdatableTable, ctx *sql.Context) map[string]sql.RowUpdater {
94+
func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) {
9195
updaterMap := make(map[string]sql.RowUpdater)
92-
for tableName, updatable := range updatables {
96+
for tableName, updateTarget := range updateTargets {
97+
updatable, err := GetUpdatable(updateTarget)
98+
if err != nil {
99+
return nil, err
100+
}
93101
updaterMap[tableName] = updatable.Updater(ctx)
94102
}
95-
return updaterMap
103+
return updaterMap, nil
96104
}
97105

98106
// updatableJoinTable manages the update of multiple tables.
99107
type UpdatableJoinTable struct {
100-
updatables map[string]sql.UpdatableTable
101-
joinNode sql.Node
108+
UpdateTargets map[string]sql.Node
109+
joinNode sql.Node
102110
}
103111

104112
var _ sql.UpdatableTable = (*UpdatableJoinTable)(nil)
@@ -135,8 +143,9 @@ func (u *UpdatableJoinTable) Collation() sql.CollationID {
135143

136144
// Updater implements the sql.UpdatableTable interface.
137145
func (u *UpdatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
146+
updaters, _ := getUpdaters(u.UpdateTargets, ctx)
138147
return &updatableJoinUpdater{
139-
updaterMap: getUpdaters(u.updatables, ctx),
148+
updaterMap: updaters,
140149
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()),
141150
joinSchema: u.joinNode.Schema(),
142151
}

sql/rowexec/dml.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row
416416
return nil, err
417417
}
418418

419+
updaters, err := n.GetUpdaters(ctx)
420+
if err != nil {
421+
return nil, err
422+
}
419423
return &updateJoinIter{
420424
updateSourceIter: ji,
421425
joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(),
422-
updaters: n.GetUpdaters(ctx),
426+
updaters: updaters,
423427
caches: make(map[string]sql.KeyValueCache),
424428
disposals: make(map[string]sql.DisposeFunc),
425429
joinNode: n.Child.(*plan.UpdateSource).Child,

0 commit comments

Comments
 (0)