Skip to content

Commit 0c5a323

Browse files
author
James Cor
committed
mutable
1 parent 1cef4d6 commit 0c5a323

File tree

3 files changed

+44
-26
lines changed

3 files changed

+44
-26
lines changed

sql/rowexec/dml_iters.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -565,18 +565,10 @@ func getRowHandler(clientFoundRowsToggled bool, iter sql.RowIter) accumulatorRow
565565

566566
func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
567567
switch i := iter.(type) {
568-
case sql.CustomRowIter:
568+
case sql.MutableRowIter:
569569
childIter := i.GetChildIter()
570570
childIter, sch := AddAccumulatorIter(ctx, childIter)
571-
return i.SetChildIter(childIter), sch
572-
case *callIter:
573-
childIter, sch := AddAccumulatorIter(ctx, i.innerIter)
574-
i.innerIter = childIter
575-
return i, sch
576-
case *beginEndIter:
577-
childIter, sch := AddAccumulatorIter(ctx, i.rowIter)
578-
i.rowIter = childIter
579-
return i, sch
571+
return i.WithChildIter(childIter), sch
580572
default:
581573
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
582574
rowHandler := getRowHandler(clientFoundRowsToggled, iter)

sql/rowexec/proc_iters.go

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type beginEndIter struct {
6666
rowIter sql.RowIter
6767
}
6868

69-
var _ sql.RowIter = (*beginEndIter)(nil)
69+
var _ sql.MutableRowIter = (*beginEndIter)(nil)
7070

7171
// Next implements the interface sql.RowIter.
7272
func (b *beginEndIter) Next(ctx *sql.Context) (sql.Row, error) {
@@ -99,40 +99,54 @@ func (b *beginEndIter) Close(ctx *sql.Context) error {
9999
return b.rowIter.Close(ctx)
100100
}
101101

102+
// GetChildIter implements the sql.MutableRowIter interface.
103+
func (b *beginEndIter) GetChildIter() sql.RowIter {
104+
return b.rowIter
105+
}
106+
107+
// WithChildIter implements the sql.MutableRowIter interface.
108+
func (b *beginEndIter) WithChildIter(child sql.RowIter) sql.RowIter {
109+
nb := *b
110+
nb.rowIter = child
111+
return &nb
112+
}
113+
102114
// callIter is the row iterator for *Call.
103115
type callIter struct {
104116
call *plan.Call
105117
innerIter sql.RowIter
106118
}
107119

120+
var _ sql.MutableRowIter = (*callIter)(nil)
121+
108122
// Next implements the sql.RowIter interface.
109-
func (iter *callIter) Next(ctx *sql.Context) (sql.Row, error) {
110-
return iter.innerIter.Next(ctx)
123+
func (ci *callIter) Next(ctx *sql.Context) (sql.Row, error) {
124+
return ci.innerIter.Next(ctx)
111125
}
112126

113127
// Close implements the sql.RowIter interface.
114-
func (iter *callIter) Close(ctx *sql.Context) error {
115-
err := iter.innerIter.Close(ctx)
128+
func (ci *callIter) Close(ctx *sql.Context) error {
129+
err := ci.innerIter.Close(ctx)
116130
if err != nil {
117131
return err
118132
}
119-
err = iter.call.Pref.CloseAllCursors(ctx)
133+
err = ci.call.Pref.CloseAllCursors(ctx)
120134
if err != nil {
121135
return err
122136
}
123137

124138
// Set all user and system variables from INOUT and OUT params
125-
for i, param := range iter.call.Procedure.Params {
139+
for i, param := range ci.call.Procedure.Params {
126140
if param.Direction == plan.ProcedureParamDirection_Inout ||
127-
(param.Direction == plan.ProcedureParamDirection_Out && iter.call.Pref.VariableHasBeenSet(param.Name)) {
128-
val, err := iter.call.Pref.GetVariableValue(param.Name)
141+
(param.Direction == plan.ProcedureParamDirection_Out && ci.call.Pref.VariableHasBeenSet(param.Name)) {
142+
val, err := ci.call.Pref.GetVariableValue(param.Name)
129143
if err != nil {
130144
return err
131145
}
132146

133-
typ := iter.call.Pref.GetVariableType(param.Name)
147+
typ := ci.call.Pref.GetVariableType(param.Name)
134148

135-
switch callParam := iter.call.Params[i].(type) {
149+
switch callParam := ci.call.Params[i].(type) {
136150
case *expression.UserVar:
137151
err = ctx.SetUserVariable(ctx, callParam.Name, val, typ)
138152
if err != nil {
@@ -150,9 +164,9 @@ func (iter *callIter) Close(ctx *sql.Context) error {
150164
} else if param.Direction == plan.ProcedureParamDirection_Out { // VariableHasBeenSet was false
151165
// For OUT only, if a var was not set within the procedure body, then we set the vars to nil.
152166
// If the var had a value before the call then it is basically removed.
153-
switch callParam := iter.call.Params[i].(type) {
167+
switch callParam := ci.call.Params[i].(type) {
154168
case *expression.UserVar:
155-
err = ctx.SetUserVariable(ctx, callParam.Name, nil, iter.call.Pref.GetVariableType(param.Name))
169+
err = ctx.SetUserVariable(ctx, callParam.Name, nil, ci.call.Pref.GetVariableType(param.Name))
156170
if err != nil {
157171
return err
158172
}
@@ -170,6 +184,18 @@ func (iter *callIter) Close(ctx *sql.Context) error {
170184
return nil
171185
}
172186

187+
// GetChildIter implements the sql.MutableRowIter interface.
188+
func (ci *callIter) GetChildIter() sql.RowIter {
189+
return ci.innerIter
190+
}
191+
192+
// WithChildIter implements the sql.MutableRowIter interface.
193+
func (ci *callIter) WithChildIter(child sql.RowIter) sql.RowIter {
194+
nci := *ci
195+
nci.innerIter = child
196+
return &nci
197+
}
198+
173199
type elseCaseErrorIter struct{}
174200

175201
var _ sql.RowIter = elseCaseErrorIter{}

sql/rows.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,10 @@ func (i *sliceRowIter) Close(*Context) error {
202202
return nil
203203
}
204204

205-
// CustomRowIter is an extension of RowIter for integrators that wrap RowIters.
205+
// MutableRowIter is an extension of RowIter for integrators that wrap RowIters.
206206
// It allows for analysis rules to inspect the underlying RowIters.
207-
type CustomRowIter interface {
207+
type MutableRowIter interface {
208208
RowIter
209209
GetChildIter() RowIter
210-
SetChildIter(childIter RowIter) RowIter
210+
WithChildIter(childIter RowIter) RowIter
211211
}

0 commit comments

Comments
 (0)