Skip to content

Commit 3cefd2c

Browse files
author
James Cor
committed
more progress
1 parent b124b42 commit 3cefd2c

File tree

5 files changed

+63
-79
lines changed

5 files changed

+63
-79
lines changed

engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags,
898898
var err error
899899
var sch sql.Schema
900900
// TODO: if this is does something we need to overwrite the schema with types.OkResultSchema
901-
//iter, sch, err = rowexec.AddAccumulatorIter(ctx, analyzed, iter)
901+
iter, sch, err = rowexec.AddAccumulatorIter(ctx, analyzed, iter)
902902
if err != nil {
903903
return nil, nil, err
904904
}

enginetest/memory_engine_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,18 @@ func TestSingleScript(t *testing.T) {
202202
{
203203
Name: "test script",
204204
SetUpScript: []string{
205-
"create table t (i int);",
205+
"create table t (i int primary key);",
206206
"create procedure p() begin insert into t values (1); end;",
207207
},
208208
Assertions: []queries.ScriptTestAssertion{
209209
{
210210
Query: "call p();",
211211
Expected: []sql.Row{},
212212
},
213+
//{
214+
// Query: "update t join (values row(1), row(2)) t2 (j) on t.i = t2.j set t.i = 10;",
215+
// Expected: []sql.Row{},
216+
//},
213217
},
214218
},
215219
}

sql/analyzer/apply_update_accumulators.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
// applyUpdateAccumulators wraps any Insert, Update, or Delete nodes with RowUpdateAccumulators to tally the results
2727
// for report to the client.
2828
func applyUpdateAccumulators(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
29-
return n, transform.SameTree, nil
29+
//return n, transform.SameTree, nil
3030
switch n := n.(type) {
3131
case *plan.TriggerExecutor, *plan.InsertInto, *plan.DeleteFrom, *plan.Update:
3232
accumulatorType, err := getUpdateAccumulatorType(n)
@@ -41,7 +41,7 @@ func applyUpdateAccumulators(ctx *sql.Context, a *Analyzer, n sql.Node, scope *p
4141

4242
// getUpdateAccumulatorType returns the type of accumulator needed for the node given, or an error if there's no match.
4343
func getUpdateAccumulatorType(n sql.Node) (plan.RowUpdateType, error) {
44-
return -1, nil
44+
//return -1, nil
4545
switch n := n.(type) {
4646
case *plan.TriggerExecutor:
4747
return getUpdateAccumulatorType(n.Left())

sql/rowexec/dml_iters.go

Lines changed: 53 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -516,87 +516,66 @@ type accumulatorIter struct {
516516
updateRowHandler accumulatorRowHandler
517517
}
518518

519-
func AddAccumulatorIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) (sql.RowIter, sql.Schema, error) {
520-
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) == mysql.CapabilityClientFoundRows
521519

522-
var rowHandler accumulatorRowHandler
523-
switch n := node.(type) {
524-
case *plan.TriggerExecutor:
525-
return AddAccumulatorIter(ctx, n.Left(), iter)
526-
case *plan.InsertInto:
527-
if n.IsReplace {
528-
rowHandler = &replaceRowHandler{}
529-
} else if len(n.OnDupExprs) > 0 {
530-
rowHandler = &onDuplicateUpdateHandler{schema: n.Schema(), clientFoundRowsCapability: clientFoundRowsToggled}
531-
} else {
532-
rowHandler = &insertRowHandler{}
520+
func getRowHandler(ctx *sql.Context, clientFoundRowsToggled bool, iter sql.RowIter) accumulatorRowHandler {
521+
switch i := iter.(type) {
522+
case *plan.TableEditorIter:
523+
return getRowHandler(ctx, clientFoundRowsToggled, i.InnerIter())
524+
case *plan.CheckpointingTableEditorIter:
525+
return getRowHandler(ctx, clientFoundRowsToggled, i.InnerIter())
526+
case *ProjectIter:
527+
return getRowHandler(ctx, clientFoundRowsToggled, i.childIter)
528+
case *triggerIter:
529+
return getRowHandler(ctx, clientFoundRowsToggled, i.child)
530+
case *insertIter:
531+
if i.replacer != nil {
532+
return &replaceRowHandler{}
533533
}
534-
case *plan.DeleteFrom:
535-
rowHandler = &deleteRowHandler{}
536-
case *plan.Update:
537-
// search for a join
538-
hasJoin := false
539-
transform.Inspect(n, func(node sql.Node) bool {
540-
switch node.(type) {
541-
case *plan.JoinNode:
542-
hasJoin = true
543-
return false
544-
}
545-
return true
546-
})
547-
if hasJoin {
548-
var schema sql.Schema
549-
var updaterMap map[string]sql.RowUpdater
550-
transform.Inspect(n, func(node sql.Node) bool {
551-
switch nn := node.(type) {
552-
case *plan.JoinNode, *plan.Project:
553-
schema = node.Schema()
554-
return false
555-
case *plan.UpdateJoin:
556-
updaterMap = nn.Updaters
557-
return true
558-
default:
559-
return true
560-
}
561-
})
562-
if schema == nil {
563-
return nil, nil, fmt.Errorf("error: No JoinNode found in query plan to go along with an UpdateTypeJoinUpdate")
564-
}
565-
// assign row handler to updateJoinIter
566-
rowHandler = &updateJoinRowHandler{joinSchema: schema, tableMap: plan.RecreateTableSchemaFromJoinSchema(schema), updaterMap: updaterMap}
567-
for i, done := iter, false; !done; {
568-
switch ii := i.(type) {
569-
case *plan.TableEditorIter:
570-
i = ii.InnerIter()
571-
case *ProjectIter:
572-
i = ii.childIter
573-
case *plan.CheckpointingTableEditorIter:
574-
i = ii.InnerIter()
575-
case *triggerIter:
576-
i = ii.child
577-
case *updateIter:
578-
i = ii.childIter
579-
case *updateJoinIter:
580-
ii.accumulator = rowHandler.(*updateJoinRowHandler)
581-
done = true
582-
default:
583-
return nil, nil, fmt.Errorf("failed to apply rowHandler to updateJoin, unknown type: %T", iter)
584-
}
534+
if i.updater != nil {
535+
return &onDuplicateUpdateHandler{schema: i.schema, clientFoundRowsCapability: clientFoundRowsToggled}
536+
}
537+
return &insertRowHandler{}
538+
case *deleteIter:
539+
return &deleteRowHandler{}
540+
case *updateIter:
541+
if updateJoin, ok := i.childIter.(*updateJoinIter); ok {
542+
return &updateJoinRowHandler{
543+
joinSchema: updateJoin.joinSchema,
544+
tableMap: plan.RecreateTableSchemaFromJoinSchema(updateJoin.joinSchema),
545+
updaterMap: updateJoin.updaters,
585546
}
586-
} else {
587-
// the schema of the update node is a self-concatenation of the underlying table's, so split it in half
588-
// for new / old row comparison purposes
589-
sch := n.Schema()
590-
rowHandler = &updateRowHandler{schema: sch[:len(sch)/2], clientFoundRowsCapability: clientFoundRowsToggled}
591547
}
548+
return &updateRowHandler{schema: i.schema[:len(i.schema)/2], clientFoundRowsCapability: clientFoundRowsToggled}
592549
default:
593-
return iter, nil, nil
550+
return nil
594551
}
552+
}
595553

596-
return &accumulatorIter{
597-
iter: iter,
598-
updateRowHandler: rowHandler,
599-
}, types.OkResultSchema, nil
554+
func AddAccumulatorIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) (sql.RowIter, sql.Schema, error) {
555+
switch i := iter.(type) {
556+
case *callIter:
557+
childIter, sch, err := AddAccumulatorIter(ctx, node, i.innerIter)
558+
i.innerIter = childIter
559+
return i, sch, err
560+
case *beginEndIter:
561+
childIter, sch, err := AddAccumulatorIter(ctx, node, i.rowIter)
562+
i.rowIter = childIter
563+
return i, sch, err
564+
case *blockIter:
565+
childIter, sch, err := AddAccumulatorIter(ctx, node, i.internalIter)
566+
i.internalIter = childIter
567+
return i, sch, err
568+
default:
569+
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) == mysql.CapabilityClientFoundRows
570+
rowHandler := getRowHandler(ctx, clientFoundRowsToggled, iter)
571+
if rowHandler == nil {
572+
return iter, nil, nil
573+
}
574+
return &accumulatorIter{
575+
iter: iter,
576+
updateRowHandler: rowHandler,
577+
}, types.OkResultSchema, nil
578+
}
600579
}
601580

602581
func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {

sql/rowexec/other.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ func (b *BaseBuilder) buildBlock(ctx *sql.Context, n *plan.Block, row sql.Row) (
219219
returnSch = subIterSch
220220
} else if !selectSeen {
221221
returnNode = subIterNode
222-
returnSch = subIterSch
222+
returnSch = types.OkResultSchema
223+
//returnSch = subIterSch
223224
}
224225

225226
for {

0 commit comments

Comments
 (0)