Skip to content

Commit 7e253ea

Browse files
committed
modified Update to be more similar to Delete, reduces need to UpdateJoin node, need to implement joinUpdater
1 parent cc8357c commit 7e253ea

File tree

5 files changed

+133
-59
lines changed

5 files changed

+133
-59
lines changed

enginetest/queries/update_queries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ var UpdateScriptTests = []ScriptTest{
485485
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
486486
Skip: true,
487487
Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;",
488-
ExpectedErr: sql.ErrCheckConstraintViolated,
488+
ExpectedErr: sql.ErrForeignKeyChildViolation,
489489
},
490490
{
491491
Query: "SELECT * FROM orders;",

sql/analyzer/apply_foreign_keys.go

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -124,31 +124,43 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
124124
}
125125
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126126
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
127-
updateDest, err := plan.GetUpdatable(n.Child)
128-
if err != nil {
129-
return nil, transform.SameTree, err
130-
}
131-
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
132-
// If foreign keys aren't supported then we return
133-
if !ok {
134-
return n, transform.SameTree, nil
135-
}
127+
targets := n.GetUpdateTargets()
128+
foreignKeyHandlers := make([]sql.Node, len(targets))
129+
copy(foreignKeyHandlers, targets)
136130

137-
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
138-
if err != nil {
139-
return nil, transform.SameTree, err
131+
for i, node := range targets {
132+
updateDest, err := plan.GetUpdatable(node)
133+
if err != nil {
134+
return nil, transform.SameTree, err
135+
}
136+
137+
tbl, ok := updateDest.(sql.ForeignKeyTable)
138+
if !ok {
139+
continue
140+
}
141+
fkEditor, err := getForeignKeyEditor(ctx, a, tbl, cache, fkChain, false)
142+
if err != nil {
143+
return nil, transform.SameTree, err
144+
}
145+
if fkEditor == nil {
146+
continue
147+
}
148+
foreignKeyHandlers[i] = &plan.ForeignKeyHandler{
149+
Table: tbl,
150+
Sch: updateDest.Schema(),
151+
OriginalNode: targets[i],
152+
AllUpdaters: fkChain.GetUpdaters(),
153+
}
140154
}
141-
if fkEditor == nil {
142-
return n, transform.SameTree, nil
155+
if n.IsJoin {
156+
return n.WithUpdateJoinTargets(foreignKeyHandlers), transform.NewTree, nil
157+
} else {
158+
newNode, err := n.WithChildren(foreignKeyHandlers...)
159+
if err != nil {
160+
return nil, transform.SameTree, err
161+
}
162+
return newNode, transform.NewTree, nil
143163
}
144-
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
145-
Table: fkTbl,
146-
Sch: updateDest.Schema(),
147-
OriginalNode: n.Child,
148-
Editor: fkEditor,
149-
AllUpdaters: fkChain.GetUpdaters(),
150-
})
151-
return nn, transform.NewTree, err
152164
case *plan.DeleteFrom:
153165
if plan.IsEmptyTable(n.Child) {
154166
return n, transform.SameTree, nil

sql/analyzer/assign_update_join.go

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,63 +34,55 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
3434
return n, transform.SameTree, nil
3535
}
3636

37-
updaters, err := rowUpdatersByTable(ctx, us, jn)
37+
updateJoinTargets, err := getTablesToBeUpdated(us, jn)
3838
if err != nil {
3939
return nil, transform.SameTree, err
4040
}
41-
42-
uj := plan.NewUpdateJoin(updaters, us)
43-
ret, err := n.WithChildren(uj)
44-
if err != nil {
45-
return nil, transform.SameTree, err
46-
}
47-
41+
ret := n.WithUpdateJoinTargets(updateJoinTargets)
42+
ret = ret.WithJoinSchema(jn.Schema())
4843
return ret, transform.NewTree, nil
4944
}
5045

5146
return n, transform.SameTree, nil
5247
}
5348

54-
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
55-
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
56-
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
57-
resolvedTables := getTablesByName(ij)
49+
func getTablesToBeUpdated(us sql.Node, jn sql.Node) ([]sql.Node, error) {
50+
namesOfTablesToBeUpdated := getNamesOfTablesToBeUpdated(us)
51+
resolvedTables := getTablesByName(jn)
52+
tablesToBeUpdated := make([]sql.Node, len(namesOfTablesToBeUpdated))
5853

59-
rowUpdatersByTable := make(map[string]sql.RowUpdater)
60-
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
61-
resolvedTable, ok := resolvedTables[tableToBeUpdated]
54+
for i, tableName := range namesOfTablesToBeUpdated {
55+
resolvedTable, ok := resolvedTables[tableName]
6256
if !ok {
63-
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
57+
return nil, plan.ErrUpdateForTableNotSupported.New(tableName)
6458
}
6559

6660
var table = resolvedTable.UnderlyingTable()
6761

68-
// If there is no UpdatableTable for a table being updated, error out
6962
updatable, ok := table.(sql.UpdatableTable)
7063
if !ok && updatable == nil {
71-
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
64+
return nil, plan.ErrUpdateForTableNotSupported.New(tableName)
7265
}
7366

7467
keyless := sql.IsKeyless(updatable.Schema())
7568
if keyless {
7669
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
7770
}
78-
79-
rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
71+
tablesToBeUpdated[i] = resolvedTable
8072
}
8173

82-
return rowUpdatersByTable, nil
74+
return tablesToBeUpdated, nil
8375
}
8476

85-
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
86-
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
87-
ret := make(map[string]struct{})
77+
// getNamesOfTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
78+
func getNamesOfTablesToBeUpdated(node sql.Node) []string {
79+
ret := make([]string, 0)
8880

8981
transform.InspectExpressions(node, func(e sql.Expression) bool {
9082
switch e := e.(type) {
9183
case *expression.SetField:
9284
gf := e.LeftChild.(*expression.GetField)
93-
ret[strings.ToLower(gf.Table())] = struct{}{}
85+
ret = append(ret, strings.ToLower(gf.Table()))
9486
return false
9587
}
9688

sql/plan/update.go

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@ var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but ex
3131
// Update is a node for updating rows on tables.
3232
type Update struct {
3333
UnaryNode
34-
checks sql.CheckConstraints
35-
Ignore bool
36-
IsJoin bool
37-
HasSingleRel bool
38-
IsProcNested bool
34+
checks sql.CheckConstraints
35+
Ignore bool
36+
IsJoin bool
37+
updateJoinTargets []sql.Node
38+
joinSchema sql.Schema
39+
HasSingleRel bool
40+
IsProcNested bool
3941

4042
// Returning is a list of expressions to return after the update operation. This feature is not
4143
// supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
@@ -168,8 +170,17 @@ func (u *Update) Expressions() []sql.Expression {
168170
return exprs
169171
}
170172

173+
func (u *Update) updateJoinTargetsResolved() bool {
174+
for _, target := range u.updateJoinTargets {
175+
if target.Resolved() == false {
176+
return false
177+
}
178+
}
179+
return true
180+
}
181+
171182
func (u *Update) Resolved() bool {
172-
return u.Child.Resolved() &&
183+
return u.Child.Resolved() && u.updateJoinTargetsResolved() &&
173184
expression.ExpressionsResolved(u.checks.ToExpressions()...) &&
174185
expression.ExpressionsResolved(u.Returning...)
175186

@@ -192,6 +203,57 @@ func (u Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
192203
return &u, nil
193204
}
194205

206+
// WithUpdateJoinTargets returns a new Update node instance with the specified |targets| set as the update join targets
207+
// of the update operation
208+
func (u *Update) WithUpdateJoinTargets(targets []sql.Node) *Update {
209+
ret := *u
210+
ret.updateJoinTargets = targets
211+
return &ret
212+
}
213+
214+
// GetUpdateTargets returns the sql.Nodes representing the tables from which rows should be updated
215+
func (u *Update) GetUpdateTargets() []sql.Node {
216+
if u.IsJoin {
217+
return u.updateJoinTargets
218+
}
219+
return []sql.Node{u.Child}
220+
}
221+
222+
func (u *Update) WithJoinSchema(schema sql.Schema) *Update {
223+
ret := *u
224+
ret.joinSchema = schema
225+
return &ret
226+
}
227+
228+
func (u *Update) JoinUpdater() sql.RowUpdater {
229+
updaters := make([]sql.RowUpdater, len(u.updateJoinTargets))
230+
return &joinUpdater{
231+
updaters: updaters,
232+
joinSchema: u.joinSchema,
233+
}
234+
}
235+
236+
type joinUpdater struct {
237+
updaters []sql.RowUpdater
238+
joinSchema sql.Schema
239+
}
240+
241+
var _ sql.RowUpdater = (*joinUpdater)(nil)
242+
243+
func (u *joinUpdater) StatementBegin(ctx *sql.Context) {}
244+
func (u *joinUpdater) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
245+
return nil
246+
}
247+
func (u *joinUpdater) StatementComplete(ctx *sql.Context) error {
248+
return nil
249+
}
250+
func (u *joinUpdater) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
251+
return nil
252+
}
253+
func (u *joinUpdater) Close(ctx *sql.Context) error {
254+
return nil
255+
}
256+
195257
// UpdateInfo is the Info for OKResults returned by Update nodes.
196258
type UpdateInfo struct {
197259
Matched, Updated, Warnings int

sql/rowexec/dml.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,18 +157,26 @@ func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKe
157157
}
158158

159159
func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row) (sql.RowIter, error) {
160-
updatable, err := plan.GetUpdatable(n.Child)
161-
if err != nil {
162-
return nil, err
160+
var updater sql.RowUpdater
161+
var schema sql.Schema
162+
if n.IsJoin {
163+
updater = n.JoinUpdater()
164+
schema = n.Schema()
165+
} else {
166+
updatable, err := plan.GetUpdatable(n.Child)
167+
if err != nil {
168+
return nil, err
169+
}
170+
updater = updatable.Updater(ctx)
171+
schema = updatable.Schema()
163172
}
164-
updater := updatable.Updater(ctx)
165173

166174
iter, err := b.buildNodeExec(ctx, n.Child, row)
167175
if err != nil {
168176
return nil, err
169177
}
170178

171-
return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore, n.Returning, n.Schema()), nil
179+
return newUpdateIter(iter, schema, updater, n.Checks(), n.Ignore, n.Returning, n.Schema()), nil
172180
}
173181

174182
func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignKey, row sql.Row) (sql.RowIter, error) {

0 commit comments

Comments
 (0)