@@ -16,7 +16,6 @@ package plan
1616
1717import (
1818 "fmt"
19-
2019 "gopkg.in/src-d/go-errors.v1"
2120
2221 "github.com/dolthub/go-mysql-server/sql"
@@ -170,17 +169,8 @@ func (u *Update) Expressions() []sql.Expression {
170169 return exprs
171170}
172171
173- func (u * Update ) updateJoinTargetsResolved () bool {
174- for _ , target := range u .updateJoinTargets {
175- if target .Resolved () == false {
176- return false
177- }
178- }
179- return true
180- }
181-
182172func (u * Update ) Resolved () bool {
183- return u .Child .Resolved () && u . updateJoinTargetsResolved () &&
173+ return u .Child .Resolved () &&
184174 expression .ExpressionsResolved (u .checks .ToExpressions ()... ) &&
185175 expression .ExpressionsResolved (u .Returning ... )
186176
@@ -264,31 +254,97 @@ func (u *Update) WithJoinSchema(schema sql.Schema) *Update {
264254 return & ret
265255}
266256
267- func (u * Update ) JoinUpdater () sql.RowUpdater {
268- updaters := make ([]sql.RowUpdater , len (u .updateJoinTargets ))
269- return & joinUpdater {
270- updaters : updaters ,
271- joinSchema : u .joinSchema ,
257+ func (u * Update ) GetUpdaterAndSchema (ctx * sql.Context ) (sql.RowUpdater , sql.Schema , error ) {
258+ if u .IsJoin {
259+ updaterMap := make (map [string ]sql.RowUpdater )
260+ for _ , target := range u .updateJoinTargets {
261+ targetTable , err := GetUpdatable (target )
262+ if err != nil {
263+ return nil , nil , err
264+ }
265+ updaterMap [targetTable .Name ()] = targetTable .Updater (ctx )
266+ }
267+ return & joinUpdater {
268+ updaterMap : updaterMap ,
269+ schemaMap : RecreateTableSchemaFromJoinSchema (u .joinSchema ),
270+ joinSchema : u .joinSchema ,
271+ }, u .joinSchema , nil
272272 }
273+ updatable , err := GetUpdatable (u .Child )
274+ if err != nil {
275+ return nil , nil , err
276+ }
277+ return updatable .Updater (ctx ), updatable .Schema (), nil
273278}
274279
275280type joinUpdater struct {
276- updaters []sql.RowUpdater
281+ updaterMap map [string ]sql.RowUpdater
282+ schemaMap map [string ]sql.Schema
277283 joinSchema sql.Schema
278284}
279285
280286var _ sql.RowUpdater = (* joinUpdater )(nil )
281287
282- func (u * joinUpdater ) StatementBegin (ctx * sql.Context ) {}
288+ // StatementBegins implements the sql.TableEditor interface
289+ func (u * joinUpdater ) StatementBegin (ctx * sql.Context ) {
290+ for _ , updater := range u .updaterMap {
291+ updater .StatementBegin (ctx )
292+ }
293+ }
294+
295+ // DiscardChanges implements the sql.TableEditor interface
283296func (u * joinUpdater ) DiscardChanges (ctx * sql.Context , errorEncountered error ) error {
297+ for _ , updater := range u .updaterMap {
298+ err := updater .DiscardChanges (ctx , errorEncountered )
299+ if err != nil {
300+ return err
301+ }
302+ }
284303 return nil
285304}
305+
306+ // StatementComplete implements the sql.TableEditor interface
286307func (u * joinUpdater ) StatementComplete (ctx * sql.Context ) error {
308+ for _ , updater := range u .updaterMap {
309+ err := updater .StatementComplete (ctx )
310+ if err != nil {
311+ return err
312+ }
313+ }
287314 return nil
288315}
289316func (u * joinUpdater ) Update (ctx * sql.Context , old sql.Row , new sql.Row ) error {
317+ tableToOldRowMap := SplitRowIntoTableRowMap (old , u .joinSchema )
318+ tableToNewRowMap := SplitRowIntoTableRowMap (new , u .joinSchema )
319+
320+ for tableName , updater := range u .updaterMap {
321+ oldRow := tableToOldRowMap [tableName ]
322+ newRow := tableToNewRowMap [tableName ]
323+ schema := u .schemaMap [tableName ]
324+
325+ eq , err := oldRow .Equals (ctx , newRow , schema )
326+ if err != nil {
327+ return err
328+ }
329+
330+ if ! eq {
331+ err = updater .Update (ctx , oldRow , newRow )
332+ }
333+
334+ if err != nil {
335+ return err
336+ }
337+ }
338+
290339 return nil
291340}
341+
292342func (u * joinUpdater ) Close (ctx * sql.Context ) error {
343+ for _ , updater := range u .updaterMap {
344+ err := updater .Close (ctx )
345+ if err != nil {
346+ return err
347+ }
348+ }
293349 return nil
294350}
0 commit comments