@@ -21,15 +21,15 @@ import (
2121)
2222
2323type UpdateJoin struct {
24- Updatables map [string ]sql.UpdatableTable
24+ UpdateTargets map [string ]sql.Node
2525 UnaryNode
2626}
2727
2828// NewUpdateJoin returns an *UpdateJoin node.
29- func NewUpdateJoin (updatablesMap map [string ]sql.UpdatableTable , child sql.Node ) * UpdateJoin {
29+ func NewUpdateJoin (updateTargets map [string ]sql.Node , child sql.Node ) * UpdateJoin {
3030 return & UpdateJoin {
31- Updatables : updatablesMap ,
32- UnaryNode : UnaryNode {Child : child },
31+ UpdateTargets : updateTargets ,
32+ UnaryNode : UnaryNode {Child : child },
3333 }
3434}
3535
@@ -60,8 +60,8 @@ func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
6060 // doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
6161 // We should revamp this function so that we can communicate multiple tables being updated.
6262 return & UpdatableJoinTable {
63- updatables : u .Updatables ,
64- joinNode : u .Child .(* UpdateSource ).Child ,
63+ UpdateTargets : u .UpdateTargets ,
64+ joinNode : u .Child .(* UpdateSource ).Child ,
6565 }
6666}
6767
@@ -71,7 +71,11 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
7171 return nil , sql .ErrInvalidChildrenNumber .New (u , len (children ), 1 )
7272 }
7373
74- return NewUpdateJoin (u .Updatables , children [0 ]), nil
74+ return NewUpdateJoin (u .UpdateTargets , children [0 ]), nil
75+ }
76+
77+ func (u * UpdateJoin ) WithUpdateTargets (updateTargets map [string ]sql.Node ) * UpdateJoin {
78+ return NewUpdateJoin (updateTargets , u .Child )
7579}
7680
7781func (u * UpdateJoin ) IsReadOnly () bool {
@@ -83,22 +87,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
8387 return sql .GetCoercibility (ctx , u .Child )
8488}
8589
86- func (u * UpdateJoin ) GetUpdaters (ctx * sql.Context ) map [string ]sql.RowUpdater {
87- return getUpdaters (u .Updatables , ctx )
90+ func (u * UpdateJoin ) GetUpdaters (ctx * sql.Context ) ( map [string ]sql.RowUpdater , error ) {
91+ return getUpdaters (u .UpdateTargets , ctx )
8892}
8993
90- func getUpdaters (updatables map [string ]sql.UpdatableTable , ctx * sql.Context ) map [string ]sql.RowUpdater {
94+ func getUpdaters (updateTargets map [string ]sql.Node , ctx * sql.Context ) ( map [string ]sql.RowUpdater , error ) {
9195 updaterMap := make (map [string ]sql.RowUpdater )
92- for tableName , updatable := range updatables {
96+ for tableName , updateTarget := range updateTargets {
97+ updatable , err := GetUpdatable (updateTarget )
98+ if err != nil {
99+ return nil , err
100+ }
93101 updaterMap [tableName ] = updatable .Updater (ctx )
94102 }
95- return updaterMap
103+ return updaterMap , nil
96104}
97105
98106// updatableJoinTable manages the update of multiple tables.
99107type UpdatableJoinTable struct {
100- updatables map [string ]sql.UpdatableTable
101- joinNode sql.Node
108+ UpdateTargets map [string ]sql.Node
109+ joinNode sql.Node
102110}
103111
104112var _ sql.UpdatableTable = (* UpdatableJoinTable )(nil )
@@ -135,8 +143,9 @@ func (u *UpdatableJoinTable) Collation() sql.CollationID {
135143
136144// Updater implements the sql.UpdatableTable interface.
137145func (u * UpdatableJoinTable ) Updater (ctx * sql.Context ) sql.RowUpdater {
146+ updaters , _ := getUpdaters (u .UpdateTargets , ctx )
138147 return & updatableJoinUpdater {
139- updaterMap : getUpdaters ( u . updatables , ctx ) ,
148+ updaterMap : updaters ,
140149 schemaMap : RecreateTableSchemaFromJoinSchema (u .joinNode .Schema ()),
141150 joinSchema : u .joinNode .Schema (),
142151 }
0 commit comments