@@ -16,7 +16,6 @@ package plan
16
16
17
17
import (
18
18
"fmt"
19
-
20
19
"gopkg.in/src-d/go-errors.v1"
21
20
22
21
"github.com/dolthub/go-mysql-server/sql"
@@ -170,17 +169,8 @@ func (u *Update) Expressions() []sql.Expression {
170
169
return exprs
171
170
}
172
171
173
- func (u * Update ) updateJoinTargetsResolved () bool {
174
- for _ , target := range u .updateJoinTargets {
175
- if target .Resolved () == false {
176
- return false
177
- }
178
- }
179
- return true
180
- }
181
-
182
172
func (u * Update ) Resolved () bool {
183
- return u .Child .Resolved () && u . updateJoinTargetsResolved () &&
173
+ return u .Child .Resolved () &&
184
174
expression .ExpressionsResolved (u .checks .ToExpressions ()... ) &&
185
175
expression .ExpressionsResolved (u .Returning ... )
186
176
@@ -264,31 +254,97 @@ func (u *Update) WithJoinSchema(schema sql.Schema) *Update {
264
254
return & ret
265
255
}
266
256
267
- func (u * Update ) JoinUpdater () sql.RowUpdater {
268
- updaters := make ([]sql.RowUpdater , len (u .updateJoinTargets ))
269
- return & joinUpdater {
270
- updaters : updaters ,
271
- joinSchema : u .joinSchema ,
257
+ func (u * Update ) GetUpdaterAndSchema (ctx * sql.Context ) (sql.RowUpdater , sql.Schema , error ) {
258
+ if u .IsJoin {
259
+ updaterMap := make (map [string ]sql.RowUpdater )
260
+ for _ , target := range u .updateJoinTargets {
261
+ targetTable , err := GetUpdatable (target )
262
+ if err != nil {
263
+ return nil , nil , err
264
+ }
265
+ updaterMap [targetTable .Name ()] = targetTable .Updater (ctx )
266
+ }
267
+ return & joinUpdater {
268
+ updaterMap : updaterMap ,
269
+ schemaMap : RecreateTableSchemaFromJoinSchema (u .joinSchema ),
270
+ joinSchema : u .joinSchema ,
271
+ }, u .joinSchema , nil
272
272
}
273
+ updatable , err := GetUpdatable (u .Child )
274
+ if err != nil {
275
+ return nil , nil , err
276
+ }
277
+ return updatable .Updater (ctx ), updatable .Schema (), nil
273
278
}
274
279
275
280
type joinUpdater struct {
276
- updaters []sql.RowUpdater
281
+ updaterMap map [string ]sql.RowUpdater
282
+ schemaMap map [string ]sql.Schema
277
283
joinSchema sql.Schema
278
284
}
279
285
280
286
var _ sql.RowUpdater = (* joinUpdater )(nil )
281
287
282
- func (u * joinUpdater ) StatementBegin (ctx * sql.Context ) {}
288
+ // StatementBegins implements the sql.TableEditor interface
289
+ func (u * joinUpdater ) StatementBegin (ctx * sql.Context ) {
290
+ for _ , updater := range u .updaterMap {
291
+ updater .StatementBegin (ctx )
292
+ }
293
+ }
294
+
295
+ // DiscardChanges implements the sql.TableEditor interface
283
296
func (u * joinUpdater ) DiscardChanges (ctx * sql.Context , errorEncountered error ) error {
297
+ for _ , updater := range u .updaterMap {
298
+ err := updater .DiscardChanges (ctx , errorEncountered )
299
+ if err != nil {
300
+ return err
301
+ }
302
+ }
284
303
return nil
285
304
}
305
+
306
+ // StatementComplete implements the sql.TableEditor interface
286
307
func (u * joinUpdater ) StatementComplete (ctx * sql.Context ) error {
308
+ for _ , updater := range u .updaterMap {
309
+ err := updater .StatementComplete (ctx )
310
+ if err != nil {
311
+ return err
312
+ }
313
+ }
287
314
return nil
288
315
}
289
316
func (u * joinUpdater ) Update (ctx * sql.Context , old sql.Row , new sql.Row ) error {
317
+ tableToOldRowMap := SplitRowIntoTableRowMap (old , u .joinSchema )
318
+ tableToNewRowMap := SplitRowIntoTableRowMap (new , u .joinSchema )
319
+
320
+ for tableName , updater := range u .updaterMap {
321
+ oldRow := tableToOldRowMap [tableName ]
322
+ newRow := tableToNewRowMap [tableName ]
323
+ schema := u .schemaMap [tableName ]
324
+
325
+ eq , err := oldRow .Equals (ctx , newRow , schema )
326
+ if err != nil {
327
+ return err
328
+ }
329
+
330
+ if ! eq {
331
+ err = updater .Update (ctx , oldRow , newRow )
332
+ }
333
+
334
+ if err != nil {
335
+ return err
336
+ }
337
+ }
338
+
290
339
return nil
291
340
}
341
+
292
342
func (u * joinUpdater ) Close (ctx * sql.Context ) error {
343
+ for _ , updater := range u .updaterMap {
344
+ err := updater .Close (ctx )
345
+ if err != nil {
346
+ return err
347
+ }
348
+ }
293
349
return nil
294
350
}
0 commit comments