Skip to content

Commit 4cf9d24

Browse files
committed
reverted most changes, modified UpdateJoin to contain Updatables instead of Updaters
1 parent e03b11a commit 4cf9d24

File tree

5 files changed

+90
-199
lines changed

5 files changed

+90
-199
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -122,44 +122,33 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125-
targets := n.GetUpdateTargets()
126-
foreignKeyHandlers := make([]sql.Node, len(targets))
127-
copy(foreignKeyHandlers, targets)
128-
129-
for i, node := range targets {
130-
updateDest, err := plan.GetUpdatable(node)
131-
if err != nil {
132-
return nil, transform.SameTree, err
133-
}
125+
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126+
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
127+
updateDest, err := plan.GetUpdatable(n.Child)
128+
if err != nil {
129+
return nil, transform.SameTree, err
130+
}
131+
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
132+
// If foreign keys aren't supported then we return
133+
if !ok {
134+
return n, transform.SameTree, nil
135+
}
134136

135-
tbl, ok := updateDest.(sql.ForeignKeyTable)
136-
if !ok {
137-
continue
138-
}
139-
fkEditor, err := getForeignKeyEditor(ctx, a, tbl, cache, fkChain, false)
140-
if err != nil {
141-
return nil, transform.SameTree, err
142-
}
143-
if fkEditor == nil {
144-
continue
145-
}
146-
foreignKeyHandlers[i] = &plan.ForeignKeyHandler{
147-
Table: tbl,
148-
Sch: updateDest.Schema(),
149-
OriginalNode: targets[i],
150-
Editor: fkEditor,
151-
AllUpdaters: fkChain.GetUpdaters(),
152-
}
137+
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
138+
if err != nil {
139+
return nil, transform.SameTree, err
153140
}
154-
if n.IsJoin {
155-
return n.WithUpdateJoinTargets(foreignKeyHandlers), transform.NewTree, nil
156-
} else {
157-
newNode, err := n.WithChildren(foreignKeyHandlers...)
158-
if err != nil {
159-
return nil, transform.SameTree, err
160-
}
161-
return newNode, transform.NewTree, nil
141+
if fkEditor == nil {
142+
return n, transform.SameTree, nil
162143
}
144+
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
145+
Table: fkTbl,
146+
Sch: updateDest.Schema(),
147+
OriginalNode: n.Child,
148+
Editor: fkEditor,
149+
AllUpdaters: fkChain.GetUpdaters(),
150+
})
151+
return nn, transform.NewTree, err
163152
case *plan.DeleteFrom:
164153
if plan.IsEmptyTable(n.Child) {
165154
return n, transform.SameTree, nil

sql/analyzer/assign_update_join.go

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,55 +34,63 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
3434
return n, transform.SameTree, nil
3535
}
3636

37-
updateJoinTargets, err := getTablesToBeUpdated(us, jn)
37+
updatables, err := updatablesByTable(us, jn)
3838
if err != nil {
3939
return nil, transform.SameTree, err
4040
}
41-
ret := n.WithUpdateJoinTargets(updateJoinTargets)
42-
ret = ret.WithJoinSchema(jn.Schema())
41+
42+
uj := plan.NewUpdateJoin(updatables, us)
43+
ret, err := n.WithChildren(uj)
44+
if err != nil {
45+
return nil, transform.SameTree, err
46+
}
47+
4348
return ret, transform.NewTree, nil
4449
}
4550

4651
return n, transform.SameTree, nil
4752
}
4853

49-
func getTablesToBeUpdated(us sql.Node, jn sql.Node) ([]sql.Node, error) {
50-
namesOfTablesToBeUpdated := getNamesOfTablesToBeUpdated(us)
51-
resolvedTables := getTablesByName(jn)
52-
tablesToBeUpdated := make([]sql.Node, len(namesOfTablesToBeUpdated))
54+
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
55+
func updatablesByTable(node sql.Node, ij sql.Node) (map[string]sql.UpdatableTable, error) {
56+
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
57+
resolvedTables := getTablesByName(ij)
5358

54-
for i, tableName := range namesOfTablesToBeUpdated {
55-
resolvedTable, ok := resolvedTables[tableName]
59+
updatables := make(map[string]sql.UpdatableTable)
60+
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
61+
resolvedTable, ok := resolvedTables[tableToBeUpdated]
5662
if !ok {
57-
return nil, plan.ErrUpdateForTableNotSupported.New(tableName)
63+
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
5864
}
5965

6066
var table = resolvedTable.UnderlyingTable()
6167

68+
// If there is no UpdatableTable for a table being updated, error out
6269
updatable, ok := table.(sql.UpdatableTable)
6370
if !ok && updatable == nil {
64-
return nil, plan.ErrUpdateForTableNotSupported.New(tableName)
71+
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
6572
}
6673

6774
keyless := sql.IsKeyless(updatable.Schema())
6875
if keyless {
6976
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
7077
}
71-
tablesToBeUpdated[i] = resolvedTable
78+
79+
updatables[tableToBeUpdated] = updatable
7280
}
7381

74-
return tablesToBeUpdated, nil
82+
return updatables, nil
7583
}
7684

77-
// getNamesOfTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
78-
func getNamesOfTablesToBeUpdated(node sql.Node) []string {
79-
ret := make([]string, 0)
85+
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
86+
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
87+
ret := make(map[string]struct{})
8088

8189
transform.InspectExpressions(node, func(e sql.Expression) bool {
8290
switch e := e.(type) {
8391
case *expression.SetField:
8492
gf := e.LeftChild.(*expression.GetField)
85-
ret = append(ret, strings.ToLower(gf.Table()))
93+
ret[strings.ToLower(gf.Table())] = struct{}{}
8694
return false
8795
}
8896

sql/plan/update.go

Lines changed: 5 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@ var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but ex
3131
// Update is a node for updating rows on tables.
3232
type Update struct {
3333
UnaryNode
34-
checks sql.CheckConstraints
35-
Ignore bool
36-
IsJoin bool
37-
updateJoinTargets []sql.Node
38-
joinSchema sql.Schema
39-
HasSingleRel bool
40-
IsProcNested bool
34+
checks sql.CheckConstraints
35+
Ignore bool
36+
IsJoin bool
37+
HasSingleRel bool
38+
IsProcNested bool
4139

4240
// Returning is a list of expressions to return after the update operation. This feature is not
4341
// supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
@@ -232,120 +230,3 @@ func (u *Update) DebugString() string {
232230
_ = pr.WriteChildren(sql.DebugString(u.Child))
233231
return pr.String()
234232
}
235-
236-
// WithUpdateJoinTargets returns a new Update node instance with the specified |targets| set as the update join targets
237-
// of the update operation
238-
func (u *Update) WithUpdateJoinTargets(targets []sql.Node) *Update {
239-
ret := *u
240-
ret.updateJoinTargets = targets
241-
return &ret
242-
}
243-
244-
// GetUpdateTargets returns the sql.Nodes representing the tables from which rows should be updated
245-
func (u *Update) GetUpdateTargets() []sql.Node {
246-
if u.IsJoin {
247-
return u.updateJoinTargets
248-
}
249-
return []sql.Node{u.Child}
250-
}
251-
252-
func (u *Update) WithJoinSchema(schema sql.Schema) *Update {
253-
ret := *u
254-
ret.joinSchema = schema
255-
return &ret
256-
}
257-
258-
func (u *Update) GetUpdaterAndSchema(ctx *sql.Context) (sql.RowUpdater, sql.Schema, error) {
259-
if u.IsJoin {
260-
updaterMap := make(map[string]sql.RowUpdater)
261-
for _, target := range u.updateJoinTargets {
262-
targetTable, err := GetUpdatable(target)
263-
if err != nil {
264-
return nil, nil, err
265-
}
266-
updaterMap[targetTable.Name()] = targetTable.Updater(ctx)
267-
}
268-
return &joinUpdater{
269-
updaterMap: updaterMap,
270-
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinSchema),
271-
joinSchema: u.joinSchema,
272-
}, u.joinSchema, nil
273-
}
274-
updatable, err := GetUpdatable(u.Child)
275-
if err != nil {
276-
return nil, nil, err
277-
}
278-
return updatable.Updater(ctx), updatable.Schema(), nil
279-
}
280-
281-
type joinUpdater struct {
282-
updaterMap map[string]sql.RowUpdater
283-
schemaMap map[string]sql.Schema
284-
joinSchema sql.Schema
285-
}
286-
287-
var _ sql.RowUpdater = (*joinUpdater)(nil)
288-
289-
// StatementBegins implements the sql.TableEditor interface
290-
func (u *joinUpdater) StatementBegin(ctx *sql.Context) {
291-
for _, updater := range u.updaterMap {
292-
updater.StatementBegin(ctx)
293-
}
294-
}
295-
296-
// DiscardChanges implements the sql.TableEditor interface
297-
func (u *joinUpdater) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
298-
for _, updater := range u.updaterMap {
299-
err := updater.DiscardChanges(ctx, errorEncountered)
300-
if err != nil {
301-
return err
302-
}
303-
}
304-
return nil
305-
}
306-
307-
// StatementComplete implements the sql.TableEditor interface
308-
func (u *joinUpdater) StatementComplete(ctx *sql.Context) error {
309-
for _, updater := range u.updaterMap {
310-
err := updater.StatementComplete(ctx)
311-
if err != nil {
312-
return err
313-
}
314-
}
315-
return nil
316-
}
317-
func (u *joinUpdater) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
318-
tableToOldRowMap := SplitRowIntoTableRowMap(old, u.joinSchema)
319-
tableToNewRowMap := SplitRowIntoTableRowMap(new, u.joinSchema)
320-
321-
for tableName, updater := range u.updaterMap {
322-
oldRow := tableToOldRowMap[tableName]
323-
newRow := tableToNewRowMap[tableName]
324-
schema := u.schemaMap[tableName]
325-
326-
eq, err := oldRow.Equals(ctx, newRow, schema)
327-
if err != nil {
328-
return err
329-
}
330-
331-
if !eq {
332-
err = updater.Update(ctx, oldRow, newRow)
333-
}
334-
335-
if err != nil {
336-
return err
337-
}
338-
}
339-
340-
return nil
341-
}
342-
343-
func (u *joinUpdater) Close(ctx *sql.Context) error {
344-
for _, updater := range u.updaterMap {
345-
err := updater.Close(ctx)
346-
if err != nil {
347-
return err
348-
}
349-
}
350-
return nil
351-
}

0 commit comments

Comments
 (0)