Skip to content

Commit 4329119

Browse files
author
James Cor
committed
moved rowupdateaccumulator
1 parent 3cefd2c commit 4329119

File tree

5 files changed

+45
-28
lines changed

5 files changed

+45
-28
lines changed

sql/analyzer/apply_update_accumulators.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
// applyUpdateAccumulators wraps any Insert, Update, or Delete nodes with RowUpdateAccumulators to tally the results
2727
// for report to the client.
2828
func applyUpdateAccumulators(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
29-
//return n, transform.SameTree, nil
29+
return n, transform.SameTree, nil
3030
switch n := n.(type) {
3131
case *plan.TriggerExecutor, *plan.InsertInto, *plan.DeleteFrom, *plan.Update:
3232
accumulatorType, err := getUpdateAccumulatorType(n)

sql/rowexec/dml_iters.go

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -516,17 +516,18 @@ type accumulatorIter struct {
516516
updateRowHandler accumulatorRowHandler
517517
}
518518

519-
520-
func getRowHandler(ctx *sql.Context, clientFoundRowsToggled bool, iter sql.RowIter) accumulatorRowHandler {
519+
func getRowHandler(clientFoundRowsToggled bool, iter sql.RowIter) accumulatorRowHandler {
521520
switch i := iter.(type) {
522521
case *plan.TableEditorIter:
523-
return getRowHandler(ctx, clientFoundRowsToggled, i.InnerIter())
522+
return getRowHandler(clientFoundRowsToggled, i.InnerIter())
524523
case *plan.CheckpointingTableEditorIter:
525-
return getRowHandler(ctx, clientFoundRowsToggled, i.InnerIter())
524+
return getRowHandler(clientFoundRowsToggled, i.InnerIter())
526525
case *ProjectIter:
527-
return getRowHandler(ctx, clientFoundRowsToggled, i.childIter)
526+
return getRowHandler(clientFoundRowsToggled, i.childIter)
528527
case *triggerIter:
529-
return getRowHandler(ctx, clientFoundRowsToggled, i.child)
528+
return getRowHandler(clientFoundRowsToggled, i.child)
529+
case *blockIter:
530+
return getRowHandler(clientFoundRowsToggled, i.repIter)
530531
case *insertIter:
531532
if i.replacer != nil {
532533
return &replaceRowHandler{}
@@ -538,14 +539,25 @@ func getRowHandler(ctx *sql.Context, clientFoundRowsToggled bool, iter sql.RowIt
538539
case *deleteIter:
539540
return &deleteRowHandler{}
540541
case *updateIter:
541-
if updateJoin, ok := i.childIter.(*updateJoinIter); ok {
542-
return &updateJoinRowHandler{
543-
joinSchema: updateJoin.joinSchema,
544-
tableMap: plan.RecreateTableSchemaFromJoinSchema(updateJoin.joinSchema),
545-
updaterMap: updateJoin.updaters,
546-
}
542+
// it's possible that there's an updateJoinIter that's not the immediate child of updateIter
543+
rowHandler := getRowHandler(clientFoundRowsToggled, i.childIter)
544+
if rowHandler != nil {
545+
return rowHandler
546+
}
547+
sch := i.schema
548+
// special case for foreign keys, plan.ForeignKeyHandler.Schema() returns original schema
549+
if fkHandler, isFk := i.updater.(*plan.ForeignKeyHandler); isFk {
550+
sch = fkHandler.Sch
547551
}
548-
return &updateRowHandler{schema: i.schema[:len(i.schema)/2], clientFoundRowsCapability: clientFoundRowsToggled}
552+
return &updateRowHandler{schema: sch, clientFoundRowsCapability: clientFoundRowsToggled}
553+
case *updateJoinIter:
554+
rowHandler := &updateJoinRowHandler{
555+
joinSchema: i.joinSchema,
556+
tableMap: plan.RecreateTableSchemaFromJoinSchema(i.joinSchema),
557+
updaterMap: i.updaters,
558+
}
559+
i.accumulator = rowHandler
560+
return rowHandler
549561
default:
550562
return nil
551563
}
@@ -561,13 +573,9 @@ func AddAccumulatorIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) (sql.
561573
childIter, sch, err := AddAccumulatorIter(ctx, node, i.rowIter)
562574
i.rowIter = childIter
563575
return i, sch, err
564-
case *blockIter:
565-
childIter, sch, err := AddAccumulatorIter(ctx, node, i.internalIter)
566-
i.internalIter = childIter
567-
return i, sch, err
568576
default:
569-
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) == mysql.CapabilityClientFoundRows
570-
rowHandler := getRowHandler(ctx, clientFoundRowsToggled, iter)
577+
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
578+
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
571579
if rowHandler == nil {
572580
return iter, nil, nil
573581
}

sql/rowexec/other.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ func (b *BaseBuilder) buildCachedResults(ctx *sql.Context, n *plan.CachedResults
148148
func (b *BaseBuilder) buildBlock(ctx *sql.Context, n *plan.Block, row sql.Row) (sql.RowIter, error) {
149149
var returnRows []sql.Row
150150
var returnNode sql.Node
151-
var returnSch sql.Schema
151+
var returnIter sql.RowIter
152+
var returnSch sql.Schema
152153

153154
selectSeen := false
154155
for _, s := range n.Children() {
@@ -216,10 +217,12 @@ func (b *BaseBuilder) buildBlock(ctx *sql.Context, n *plan.Block, row sql.Row) (
216217
if isSelect = plan.NodeRepresentsSelect(subIterNode); isSelect {
217218
selectSeen = true
218219
returnNode = subIterNode
219-
returnSch = subIterSch
220+
returnIter = subIter
221+
returnSch = subIterSch
220222
} else if !selectSeen {
221223
returnNode = subIterNode
222-
returnSch = types.OkResultSchema
224+
returnIter = subIter
225+
returnSch = types.OkResultSchema
223226
//returnSch = subIterSch
224227
}
225228

@@ -259,7 +262,8 @@ func (b *BaseBuilder) buildBlock(ctx *sql.Context, n *plan.Block, row sql.Row) (
259262
return &blockIter{
260263
internalIter: sql.RowsToRowIter(returnRows...),
261264
repNode: returnNode,
262-
sch: returnSch,
265+
repIter: returnIter,
266+
repSch: returnSch,
263267
}, nil
264268
}
265269

sql/rowexec/other_iters.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ func (itr *dropHistogramIter) Close(_ *sql.Context) error {
115115
type blockIter struct {
116116
internalIter sql.RowIter
117117
repNode sql.Node
118-
sch sql.Schema
118+
repIter sql.RowIter
119+
repSch sql.Schema
119120
}
120121

121122
var _ plan.BlockRowIter = (*blockIter)(nil)
@@ -137,7 +138,7 @@ func (i *blockIter) RepresentingNode() sql.Node {
137138

138139
// Schema implements the sql.BlockRowIter interface.
139140
func (i *blockIter) Schema() sql.Schema {
140-
return i.sch
141+
return i.repSch
141142
}
142143

143144
type prependRowIter struct {

sql/rowexec/proc.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sq
232232

233233
var returnRows []sql.Row
234234
var returnNode sql.Node
235-
var returnSch sql.Schema
235+
var returnIter sql.RowIter
236+
var returnSch sql.Schema
236237
selectSeen := false
237238

238239
// It's technically valid to make an infinite loop, but we don't want to actually allow that
@@ -285,10 +286,12 @@ func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sq
285286
selectSeen = true
286287
includeResultSet = true
287288
returnNode = subIterNode
289+
returnIter = loopBodyIter
288290
returnSch = subIterSch
289291
} else if !selectSeen {
290292
includeResultSet = true
291293
returnNode = subIterNode
294+
returnIter = loopBodyIter
292295
returnSch = subIterSch
293296
}
294297
}
@@ -336,7 +339,8 @@ func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sq
336339
return &blockIter{
337340
internalIter: sql.RowsToRowIter(returnRows...),
338341
repNode: returnNode,
339-
sch: returnSch,
342+
repSch: returnSch,
343+
repIter: returnIter,
340344
}, nil
341345
}
342346

0 commit comments

Comments
 (0)