Commit 5036c14
sql: move processSourceRow to updateRun, upsertRun, and deleteRun
Refactor processSourceRow so that it's a method of updateRun, upsertRun, and deleteRun, rather than a method of updateNode, upsertNode, and deleteNode, respectively. This matches insertRun. This refactor will make it easier for the new UPDATE and DELETE fast path planNodes to share updateRun and deleteRun with the existing updateNode and deleteNode. (There's no new fast path in the works for UPSERT, at least not yet, but upsertRun is changed for completeness.)

Epic: None

Release note: None
1 parent 4f8b677 commit 5036c14
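
The motivation is easiest to see in a sketch. Once processSourceRow and initRowContainer hang off deleteRun rather than deleteNode, any planNode that embeds a deleteRun can drive per-row deletion itself. The fragment below is a hypothetical illustration, not part of this commit: deleteFastPathNode and its fields are invented names, while deleteRun, initRowContainer, processSourceRow, and the calls they make are taken from the diff that follows.

// Hypothetical sketch only; no such node is added by this commit.
type deleteFastPathNode struct {
	input   planNode              // assumed row source
	columns colinfo.ResultColumns // columns surfaced for RETURNING, if any
	run     deleteRun             // the same run-time state struct used by deleteNode
}

func (d *deleteFastPathNode) startExec(params runParams) error {
	d.run.traceKV = params.p.ExtendedEvalContext().Tracing.KVTracingEnabled()
	// initRowContainer (extracted in this commit) is a no-op unless rowsNeeded is set.
	d.run.initRowContainer(params, d.columns)
	return d.run.td.init(params.ctx, params.p.txn, params.EvalContext())
}

func (d *deleteFastPathNode) processRow(params runParams) error {
	// processSourceRow is now a deleteRun method, so the fast-path node can
	// call it directly instead of duplicating deleteNode's row handling.
	return d.run.processSourceRow(params, d.input.Values())
}

The same pattern applies to updateRun (and, for completeness, upsertRun): an UPDATE fast path node embedding updateRun would call u.run.processSourceRow the same way updateNode does after this change.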

File tree

3 files changed

+96 -83 lines changed

pkg/sql/delete.go

Lines changed: 37 additions & 30 deletions
@@ -32,6 +32,8 @@ type deleteNode struct {
 	run deleteRun
 }
 
+var _ mutationPlanNode = &deleteNode{}
+
 // deleteRun contains the run-time state of deleteNode during local execution.
 type deleteRun struct {
 	td tableDeleter
@@ -61,21 +63,26 @@ type deleteRun struct {
 	numPassthrough int
 }
 
-var _ mutationPlanNode = &deleteNode{}
+func (r *deleteRun) initRowContainer(params runParams, columns colinfo.ResultColumns) {
+	if !r.rowsNeeded {
+		return
+	}
+	r.td.rows = rowcontainer.NewRowContainer(
+		params.p.Mon().MakeBoundAccount(),
+		colinfo.ColTypeInfoFromResCols(columns),
+	)
+	r.resultRowBuffer = make([]tree.Datum, len(columns))
+	for i := range r.resultRowBuffer {
+		r.resultRowBuffer[i] = tree.DNull
+	}
+}
 
 func (d *deleteNode) startExec(params runParams) error {
 	// cache traceKV during execution, to avoid re-evaluating it for every row.
 	d.run.traceKV = params.p.ExtendedEvalContext().Tracing.KVTracingEnabled()
 
-	if d.run.rowsNeeded {
-		d.run.td.rows = rowcontainer.NewRowContainer(
-			params.p.Mon().MakeBoundAccount(),
-			colinfo.ColTypeInfoFromResCols(d.columns))
-		d.run.resultRowBuffer = make([]tree.Datum, len(d.columns))
-		for i := range d.run.resultRowBuffer {
-			d.run.resultRowBuffer[i] = tree.DNull
-		}
-	}
+	d.run.initRowContainer(params, d.columns)
+
 	return d.run.td.init(params.ctx, params.p.txn, params.EvalContext())
 }
 
@@ -115,7 +122,7 @@ func (d *deleteNode) BatchedNext(params runParams) (bool, error) {
 
 	// Process the deletion of the current input row,
 	// potentially accumulating the result row for later.
-	if err := d.processSourceRow(params, d.input.Values()); err != nil {
+	if err := d.run.processSourceRow(params, d.input.Values()); err != nil {
 		return false, err
 	}
 
@@ -153,18 +160,18 @@ func (d *deleteNode) BatchedNext(params runParams) (bool, error) {
 
 // processSourceRow processes one row from the source for deletion and, if
 // result rows are needed, saves it in the result row container
-func (d *deleteNode) processSourceRow(params runParams, sourceVals tree.Datums) error {
+func (r *deleteRun) processSourceRow(params runParams, sourceVals tree.Datums) error {
 	// Remove extra columns for partial index predicate values and AFTER triggers.
-	deleteVals := sourceVals[:len(d.run.td.rd.FetchCols)+d.run.numPassthrough]
+	deleteVals := sourceVals[:len(r.td.rd.FetchCols)+r.numPassthrough]
 	sourceVals = sourceVals[len(deleteVals):]
 
 	// Create a set of partial index IDs to not delete from. Indexes should not
 	// be deleted from when they are partial indexes and the row does not
 	// satisfy the predicate and therefore do not exist in the partial index.
 	// This set is passed as a argument to tableDeleter.row below.
 	var pm row.PartialIndexUpdateHelper
-	if n := len(d.run.td.tableDesc().PartialIndexes()); n > 0 {
-		err := pm.Init(nil /* partialIndexPutVals */, sourceVals[:n], d.run.td.tableDesc())
+	if n := len(r.td.tableDesc().PartialIndexes()); n > 0 {
+		err := pm.Init(nil /* partialIndexPutVals */, sourceVals[:n], r.td.tableDesc())
 		if err != nil {
 			return err
 		}
@@ -174,52 +181,52 @@ func (d *deleteNode) processSourceRow(params runParams, sourceVals tree.Datums)
 	// Keep track of the vector index partitions to update. This information is
 	// passed to tableInserter.row below.
 	var vh row.VectorIndexUpdateHelper
-	if n := len(d.run.td.tableDesc().VectorIndexes()); n > 0 {
-		vh.InitForDel(sourceVals[:n], d.run.td.tableDesc())
+	if n := len(r.td.tableDesc().VectorIndexes()); n > 0 {
+		vh.InitForDel(sourceVals[:n], r.td.tableDesc())
 	}
 
 	// Queue the deletion in the KV batch.
-	if err := d.run.td.row(
-		params.ctx, deleteVals, pm, vh, false /* mustValidateOldPKValues */, d.run.traceKV,
+	if err := r.td.row(
+		params.ctx, deleteVals, pm, vh, false /* mustValidateOldPKValues */, r.traceKV,
 	); err != nil {
 		return err
 	}
 
 	// If result rows need to be accumulated, do it.
-	if d.run.td.rows != nil {
+	if r.td.rows != nil {
 		// The new values can include all columns, so the values may contain
 		// additional columns for every newly dropped column not visible. We do not
 		// want them to be available for RETURNING.
 		//
-		// d.run.rows.NumCols() is guaranteed to only contain the requested
+		// r.rows.NumCols() is guaranteed to only contain the requested
 		// public columns.
 		largestRetIdx := -1
-		for i := range d.run.rowIdxToRetIdx {
-			retIdx := d.run.rowIdxToRetIdx[i]
+		for i := range r.rowIdxToRetIdx {
+			retIdx := r.rowIdxToRetIdx[i]
 			if retIdx >= 0 {
 				if retIdx >= largestRetIdx {
 					largestRetIdx = retIdx
 				}
-				d.run.resultRowBuffer[retIdx] = deleteVals[i]
+				r.resultRowBuffer[retIdx] = deleteVals[i]
 			}
 		}
 
 		// At this point we've extracted all the RETURNING values that are part
 		// of the target table. We must now extract the columns in the RETURNING
 		// clause that refer to other tables (from the USING clause of the delete).
-		if d.run.numPassthrough > 0 {
-			passthroughBegin := len(d.run.td.rd.FetchCols)
-			passthroughEnd := passthroughBegin + d.run.numPassthrough
+		if r.numPassthrough > 0 {
+			passthroughBegin := len(r.td.rd.FetchCols)
+			passthroughEnd := passthroughBegin + r.numPassthrough
 			passthroughValues := deleteVals[passthroughBegin:passthroughEnd]
 
-			for i := 0; i < d.run.numPassthrough; i++ {
+			for i := 0; i < r.numPassthrough; i++ {
 				largestRetIdx++
-				d.run.resultRowBuffer[largestRetIdx] = passthroughValues[i]
+				r.resultRowBuffer[largestRetIdx] = passthroughValues[i]
 			}
 
 		}
 
-		if _, err := d.run.td.rows.AddRow(params.ctx, d.run.resultRowBuffer); err != nil {
+		if _, err := r.td.rows.AddRow(params.ctx, r.resultRowBuffer); err != nil {
 			return err
 		}
 	}

pkg/sql/update.go

Lines changed: 40 additions & 34 deletions
@@ -72,20 +72,26 @@ type updateRun struct {
 	regionLocalInfo regionLocalInfoType
 }
 
+func (r *updateRun) initRowContainer(params runParams, columns colinfo.ResultColumns) {
+	if !r.rowsNeeded {
+		return
+	}
+	r.tu.rows = rowcontainer.NewRowContainer(
+		params.p.Mon().MakeBoundAccount(),
+		colinfo.ColTypeInfoFromResCols(columns),
+	)
+	r.resultRowBuffer = make([]tree.Datum, len(columns))
+	for i := range r.resultRowBuffer {
+		r.resultRowBuffer[i] = tree.DNull
+	}
+}
+
 func (u *updateNode) startExec(params runParams) error {
 	// cache traceKV during execution, to avoid re-evaluating it for every row.
 	u.run.traceKV = params.p.ExtendedEvalContext().Tracing.KVTracingEnabled()
 
-	if u.run.rowsNeeded {
-		u.run.tu.rows = rowcontainer.NewRowContainer(
-			params.p.Mon().MakeBoundAccount(),
-			colinfo.ColTypeInfoFromResCols(u.columns),
-		)
-		u.run.resultRowBuffer = make([]tree.Datum, len(u.columns))
-		for i := range u.run.resultRowBuffer {
-			u.run.resultRowBuffer[i] = tree.DNull
-		}
-	}
+	u.run.initRowContainer(params, u.columns)
+
 	return u.run.tu.init(params.ctx, params.p.txn, params.EvalContext())
 }
 
@@ -126,7 +132,7 @@ func (u *updateNode) BatchedNext(params runParams) (bool, error) {
 
 	// Process the update for the current input row, potentially
 	// accumulating the result row for later.
-	if err := u.processSourceRow(params, u.input.Values()); err != nil {
+	if err := u.run.processSourceRow(params, u.input.Values()); err != nil {
 		return false, err
 	}
 
@@ -164,7 +170,7 @@ func (u *updateNode) BatchedNext(params runParams) (bool, error) {
 
 // processSourceRow processes one row from the source for update and, if
 // result rows are needed, saves it in the result row container.
-func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) error {
+func (r *updateRun) processSourceRow(params runParams, sourceVals tree.Datums) error {
 	// sourceVals contains values for the columns from the table, in the order of the
 	// table descriptor. (One per column in u.tw.ru.FetchCols)
 	//
@@ -174,42 +180,42 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums)
 	// oldValues is the prefix of sourceVals that corresponds to real
 	// stored columns in the table, that is, excluding the RHS assignment
 	// expressions.
-	oldValues := sourceVals[:len(u.run.tu.ru.FetchCols)]
+	oldValues := sourceVals[:len(r.tu.ru.FetchCols)]
 	sourceVals = sourceVals[len(oldValues):]
 
 	// The update values follow the fetch values and their order corresponds to the order of ru.UpdateCols.
-	updateValues := sourceVals[:len(u.run.tu.ru.UpdateCols)]
+	updateValues := sourceVals[:len(r.tu.ru.UpdateCols)]
 	sourceVals = sourceVals[len(updateValues):]
 
 	// The passthrough values follow the update values.
-	passthroughValues := sourceVals[:u.run.numPassthrough]
+	passthroughValues := sourceVals[:r.numPassthrough]
 	sourceVals = sourceVals[len(passthroughValues):]
 
 	// Verify the schema constraints. For consistency with INSERT/UPSERT
 	// and compatibility with PostgreSQL, we must do this before
 	// processing the CHECK constraints.
-	if err := enforceNotNullConstraints(updateValues, u.run.tu.ru.UpdateCols); err != nil {
+	if err := enforceNotNullConstraints(updateValues, r.tu.ru.UpdateCols); err != nil {
 		return err
 	}
 
 	// Run the CHECK constraints, if any. CheckHelper will either evaluate the
 	// constraints itself, or else inspect boolean columns from the input that
 	// contain the results of evaluation.
-	if !u.run.checkOrds.Empty() {
+	if !r.checkOrds.Empty() {
 		if err := checkMutationInput(
 			params.ctx, params.EvalContext(), &params.p.semaCtx, params.p.SessionData(),
-			u.run.tu.tableDesc(), u.run.checkOrds, sourceVals[:u.run.checkOrds.Len()],
+			r.tu.tableDesc(), r.checkOrds, sourceVals[:r.checkOrds.Len()],
 		); err != nil {
 			return err
 		}
-		sourceVals = sourceVals[u.run.checkOrds.Len():]
+		sourceVals = sourceVals[r.checkOrds.Len():]
 	}
 
 	// Create a set of partial index IDs to not add entries or remove entries
 	// from. Put values are followed by del values.
 	var pm row.PartialIndexUpdateHelper
-	if n := len(u.run.tu.tableDesc().PartialIndexes()); n > 0 {
-		err := pm.Init(sourceVals[:n], sourceVals[n:n*2], u.run.tu.tableDesc())
+	if n := len(r.tu.tableDesc().PartialIndexes()); n > 0 {
+		err := pm.Init(sourceVals[:n], sourceVals[n:n*2], r.tu.tableDesc())
 		if err != nil {
 			return err
 		}
@@ -221,53 +227,53 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums)
 	// Order of column values is put partitions, quantized vectors, followed by
 	// del partitions
 	var vh row.VectorIndexUpdateHelper
-	if n := len(u.run.tu.tableDesc().VectorIndexes()); n > 0 {
-		vh.InitForPut(sourceVals[:n], sourceVals[n:n*2], u.run.tu.tableDesc())
-		vh.InitForDel(sourceVals[n*2:n*3], u.run.tu.tableDesc())
+	if n := len(r.tu.tableDesc().VectorIndexes()); n > 0 {
+		vh.InitForPut(sourceVals[:n], sourceVals[n:n*2], r.tu.tableDesc())
+		vh.InitForDel(sourceVals[n*2:n*3], r.tu.tableDesc())
 	}
 
 	// Error out the update if the enforce_home_region session setting is on and
 	// the row's locality doesn't match the gateway region.
-	if err := u.run.regionLocalInfo.checkHomeRegion(updateValues); err != nil {
+	if err := r.regionLocalInfo.checkHomeRegion(updateValues); err != nil {
 		return err
 	}
 
 	// Queue the insert in the KV batch.
-	newValues, err := u.run.tu.rowForUpdate(
-		params.ctx, oldValues, updateValues, pm, vh, false /* mustValidateOldPKValues */, u.run.traceKV,
+	newValues, err := r.tu.rowForUpdate(
+		params.ctx, oldValues, updateValues, pm, vh, false /* mustValidateOldPKValues */, r.traceKV,
 	)
 	if err != nil {
 		return err
 	}
 
 	// If result rows need to be accumulated, do it.
-	if u.run.tu.rows != nil {
+	if r.tu.rows != nil {
 		// The new values can include all columns, so the values may contain
 		// additional columns for every newly added column not yet visible. We do
 		// not want them to be available for RETURNING.
 		//
 		// MakeUpdater guarantees that the first columns of the new values
 		// are those specified u.columns.
 		largestRetIdx := -1
-		for i := range u.run.rowIdxToRetIdx {
-			retIdx := u.run.rowIdxToRetIdx[i]
+		for i := range r.rowIdxToRetIdx {
+			retIdx := r.rowIdxToRetIdx[i]
 			if retIdx >= 0 {
 				if retIdx >= largestRetIdx {
 					largestRetIdx = retIdx
 				}
-				u.run.resultRowBuffer[retIdx] = newValues[i]
+				r.resultRowBuffer[retIdx] = newValues[i]
 			}
 		}
 
 		// At this point we've extracted all the RETURNING values that are part
 		// of the target table. We must now extract the columns in the RETURNING
 		// clause that refer to other tables (from the FROM clause of the update).
-		for i := 0; i < u.run.numPassthrough; i++ {
+		for i := 0; i < r.numPassthrough; i++ {
 			largestRetIdx++
-			u.run.resultRowBuffer[largestRetIdx] = passthroughValues[i]
+			r.resultRowBuffer[largestRetIdx] = passthroughValues[i]
 		}
 
-		if _, err := u.run.tu.rows.AddRow(params.ctx, u.run.resultRowBuffer); err != nil {
+		if _, err := r.tu.rows.AddRow(params.ctx, r.resultRowBuffer); err != nil {
 			return err
 		}
 	}
