Skip to content

Commit f43dafa

Browse files
authored
Optimization: Defer Projections for Server Queries (#2676)
1 parent d73dd77 commit f43dafa

29 files changed

+296
-159
lines changed

engine.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ func (e *Engine) AnalyzeQuery(
210210
query string,
211211
) (sql.Node, error) {
212212
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
213-
parsed, _, _, qFlags, err := binder.Parse(query, false)
213+
parsed, _, _, qFlags, err := binder.Parse(query, nil, false)
214214
if err != nil {
215215
return nil, err
216216
}
@@ -238,7 +238,7 @@ func (e *Engine) PrepareParsedQuery(
238238
stmt sqlparser.Statement,
239239
) (sql.Node, error) {
240240
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
241-
node, _, err := binder.BindOnly(stmt, query)
241+
node, _, err := binder.BindOnly(stmt, query, nil)
242242

243243
if err != nil {
244244
return nil, err
@@ -586,7 +586,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat
586586
var bound sql.Node
587587
var err error
588588
if parsed == nil {
589-
bound, _, _, qFlags, err = binder.Parse(query, false)
589+
bound, _, _, qFlags, err = binder.Parse(query, qFlags, false)
590590
if err != nil {
591591
clearAutocommitErr := clearAutocommitTransaction(ctx)
592592
if clearAutocommitErr != nil {
@@ -595,7 +595,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat
595595
return nil, nil, err
596596
}
597597
} else {
598-
bound, qFlags, err = binder.BindOnly(parsed, query)
598+
bound, qFlags, err = binder.BindOnly(parsed, query, qFlags)
599599
if err != nil {
600600
return nil, nil, err
601601
}
@@ -651,7 +651,7 @@ func (e *Engine) bindExecuteQueryNode(ctx *sql.Context, query string, eq *plan.E
651651
binder.SetBindingsWithExpr(tempBindings)
652652
}
653653

654-
bound, _, err := binder.BindOnly(prep, query)
654+
bound, _, err := binder.BindOnly(prep, query, nil)
655655
if err != nil {
656656
clearAutocommitErr := clearAutocommitTransaction(ctx)
657657
if clearAutocommitErr != nil {

enginetest/engine_only_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ func TestAnalyzer_Exp(t *testing.T) {
511511

512512
ctx := enginetest.NewContext(harness)
513513
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
514-
parsed, _, _, _, err := b.Parse(tt.query, false)
514+
parsed, _, _, _, err := b.Parse(tt.query, nil, false)
515515
require.NoError(t, err)
516516

517517
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil)

enginetest/enginetests.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5653,11 +5653,11 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
56535653
break
56545654
}
56555655
expectedEngineRow := make([]*string, len(engineRow))
5656-
for i := range engineRow {
5657-
sqlVal, err := sch[i].Type.SQL(ctx, nil, engineRow[i])
5658-
if !assert.NoError(t, err) {
5659-
break
5660-
}
5656+
row, err := server.RowToSQL(ctx, sch, engineRow, nil)
5657+
if !assert.NoError(t, err) {
5658+
break
5659+
}
5660+
for i, sqlVal := range row {
56615661
if !sqlVal.IsNull() {
56625662
str := sqlVal.ToString()
56635663
expectedEngineRow[i] = &str

enginetest/evaluation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func injectBindVarsAndPrepare(
528528

529529
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
530530
b.SetParserOptions(sql.LoadSqlMode(ctx).ParserOptions())
531-
resPlan, _, err := b.BindOnly(parsed, q)
531+
resPlan, _, err := b.BindOnly(parsed, q, nil)
532532
if err != nil {
533533
return q, nil, err
534534
}

enginetest/plangen/cmd/plangen/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func generatePlansForSuite(spec PlanSpec, w *bytes.Buffer) error {
166166
if !tt.Skip {
167167
ctx := enginetest.NewContextWithEngine(harness, engine)
168168
binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, sql.NewMysqlParser())
169-
parsed, _, _, qFlags, err := binder.Parse(tt.Query, false)
169+
parsed, _, _, qFlags, err := binder.Parse(tt.Query, nil, false)
170170
if err != nil {
171171
exit(fmt.Errorf("%w\nfailed to parse query: %s", err, tt.Query))
172172
}

server/handler.go

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ import (
4040
"github.com/dolthub/go-mysql-server/internal/sockstate"
4141
"github.com/dolthub/go-mysql-server/sql"
4242
"github.com/dolthub/go-mysql-server/sql/analyzer"
43+
"github.com/dolthub/go-mysql-server/sql/iters"
4344
"github.com/dolthub/go-mysql-server/sql/plan"
45+
"github.com/dolthub/go-mysql-server/sql/rowexec"
4446
"github.com/dolthub/go-mysql-server/sql/types"
4547
)
4648

@@ -218,7 +220,7 @@ func (h *Handler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query s
218220
func (h *Handler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
219221
_, err := h.errorWrappedDoQuery(ctx, c, prepare.PrepareStmt, nil, MultiStmtModeOff, prepare.BindVars, func(res *sqltypes.Result, more bool) error {
220222
return callback(res)
221-
}, nil)
223+
}, &sql.QueryFlags{})
222224
return err
223225
}
224226

@@ -295,7 +297,7 @@ func (h *Handler) ComMultiQuery(
295297
query string,
296298
callback mysql.ResultSpoolFn,
297299
) (string, error) {
298-
return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, nil)
300+
return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, &sql.QueryFlags{})
299301
}
300302

301303
// ComQuery executes a SQL query on the SQLe engine.
@@ -305,7 +307,7 @@ func (h *Handler) ComQuery(
305307
query string,
306308
callback mysql.ResultSpoolFn,
307309
) error {
308-
_, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, nil)
310+
_, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, &sql.QueryFlags{})
309311
return err
310312
}
311313

@@ -317,7 +319,7 @@ func (h *Handler) ComParsedQuery(
317319
parsed sqlparser.Statement,
318320
callback mysql.ResultSpoolFn,
319321
) error {
320-
_, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, nil)
322+
_, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, &sql.QueryFlags{})
321323
return err
322324
}
323325

@@ -424,6 +426,7 @@ func (h *Handler) doQuery(
424426
}
425427
}()
426428

429+
qFlags.Set(sql.QFlagDeferProjections)
427430
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
428431
if err != nil {
429432
sqlCtx.GetLogger().WithError(err).Warn("error running query")
@@ -511,6 +514,37 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter, resultFields []*quer
511514
return &sqltypes.Result{Fields: resultFields}, nil
512515
}
513516

517+
// GetDeferredProjections looks for a top-level deferred projection, retrieves its projections, and removes it from the
518+
// iterator tree.
519+
func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
520+
switch i := iter.(type) {
521+
case *rowexec.ExprCloserIter:
522+
_, projs := GetDeferredProjections(i.GetIter())
523+
return i, projs
524+
case *plan.TrackedRowIter:
525+
_, projs := GetDeferredProjections(i.GetIter())
526+
return i, projs
527+
case *rowexec.TransactionCommittingIter:
528+
newChild, projs := GetDeferredProjections(i.GetIter())
529+
if projs != nil {
530+
i.WithChildIter(newChild)
531+
}
532+
return i, projs
533+
case *iters.LimitIter:
534+
newChild, projs := GetDeferredProjections(i.ChildIter)
535+
if projs != nil {
536+
i.ChildIter = newChild
537+
}
538+
return i, projs
539+
case *rowexec.ProjectIter:
540+
if i.CanDefer() {
541+
return i.GetChildIter(), i.GetProjections()
542+
}
543+
return i, nil
544+
}
545+
return iter, nil
546+
}
547+
514548
// resultForMax1RowIter ensures that an empty iterator returns at most one row
515549
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) {
516550
defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End()
@@ -527,8 +561,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
527561
if err := iter.Close(ctx); err != nil {
528562
return nil, err
529563
}
530-
531-
outputRow, err := rowToSQL(ctx, schema, row)
564+
outputRow, err := RowToSQL(ctx, schema, row, nil)
532565
if err != nil {
533566
return nil, err
534567
}
@@ -558,16 +591,11 @@ func (h *Handler) resultForDefaultIter(
558591
}
559592
}
560593

561-
pollCtx, cancelF := ctx.NewSubContext()
562-
eg.Go(func() error {
563-
defer pan2err()
564-
return h.pollForClosedConnection(pollCtx, c)
565-
})
566-
567594
wg := sync.WaitGroup{}
568595
wg.Add(2)
569596

570597
// Read rows off the row iterator and send them to the row channel.
598+
iter, projs := GetDeferredProjections(iter)
571599
var rowChan = make(chan sql.Row, 512)
572600
eg.Go(func() error {
573601
defer pan2err()
@@ -594,6 +622,12 @@ func (h *Handler) resultForDefaultIter(
594622
}
595623
})
596624

625+
pollCtx, cancelF := ctx.NewSubContext()
626+
eg.Go(func() error {
627+
defer pan2err()
628+
return h.pollForClosedConnection(pollCtx, c)
629+
})
630+
597631
// Default waitTime is one minute if there is no timeout configured, in which case
598632
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
599633
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
@@ -639,7 +673,7 @@ func (h *Handler) resultForDefaultIter(
639673
continue
640674
}
641675

642-
outputRow, err := rowToSQL(ctx, schema, row)
676+
outputRow, err := RowToSQL(ctx, schema, row, projs)
643677
if err != nil {
644678
return err
645679
}
@@ -648,6 +682,7 @@ func (h *Handler) resultForDefaultIter(
648682
r.Rows = append(r.Rows, outputRow)
649683
r.RowsAffected++
650684
case <-timer.C:
685+
// TODO: timer should probably go in its own thread, as rowChan is blocking
651686
if h.readTimeout != 0 {
652687
// Cancel and return so Vitess can call the CloseConnection callback
653688
ctx.GetLogger().Tracef("connection timeout")
@@ -901,25 +936,43 @@ func updateMaxUsedConnectionsStatusVariable() {
901936
}()
902937
}
903938

904-
func rowToSQL(ctx *sql.Context, s sql.Schema, row sql.Row) ([]sqltypes.Value, error) {
905-
o := make([]sqltypes.Value, len(row))
939+
func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) {
906940
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
907-
if len(s) == 0 {
908-
return o, nil
941+
if len(sch) == 0 {
942+
return []sqltypes.Value{}, nil
909943
}
910-
var err error
911-
for i, v := range row {
912-
if v == nil {
913-
o[i] = sqltypes.NULL
944+
945+
outVals := make([]sqltypes.Value, len(sch))
946+
if len(projs) == 0 {
947+
for i, col := range sch {
948+
if row[i] == nil {
949+
outVals[i] = sqltypes.NULL
950+
continue
951+
}
952+
var err error
953+
outVals[i], err = col.Type.SQL(ctx, nil, row[i])
954+
if err != nil {
955+
return nil, err
956+
}
957+
}
958+
return outVals, nil
959+
}
960+
961+
for i, col := range sch {
962+
field, err := projs[i].Eval(ctx, row)
963+
if err != nil {
964+
return nil, err
965+
}
966+
if field == nil {
967+
outVals[i] = sqltypes.NULL
914968
continue
915969
}
916-
o[i], err = s[i].Type.SQL(ctx, nil, v)
970+
outVals[i], err = col.Type.SQL(ctx, nil, field)
917971
if err != nil {
918972
return nil, err
919973
}
920974
}
921-
922-
return o, nil
975+
return outVals, nil
923976
}
924977

925978
func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field {

sql/analyzer/optimization_rules_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func TestPushNotFilters(t *testing.T) {
223223
for _, tt := range tests {
224224
t.Run(tt.in, func(t *testing.T) {
225225
q := fmt.Sprintf("SELECT 1 from xy WHERE %s", tt.in)
226-
node, _, _, _, err := b.Parse(q, false)
226+
node, _, _, _, err := b.Parse(q, nil, false)
227227
require.NoError(t, err)
228228

229229
cmp, _, err := pushNotFilters(ctx, nil, node, nil, nil, nil)

sql/analyzer/rules.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ var OnceBeforeDefault = []Rule{
3939
{applyDefaultSelectLimitId, applyDefaultSelectLimit},
4040
{replaceCountStarId, replaceCountStar},
4141
{applyEventSchedulerId, applyEventScheduler},
42-
{validateOffsetAndLimitId, validateLimitAndOffset},
42+
{validateOffsetAndLimitId, validateOffsetAndLimit},
4343
{validateCreateTableId, validateCreateTable},
4444
{validateAlterTableId, validateAlterTable},
4545
{validateExprSemId, validateExprSem},

sql/analyzer/stored_procedures.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
5656
var parsedProcedure sql.Node
5757
b := planbuilder.New(ctx, a.Catalog, sql.NewMysqlParser())
5858
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
59-
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false)
59+
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
6060
if err != nil {
6161
procToRegister = &plan.Procedure{
6262
CreateProcedureString: procedure.CreateStatement,
@@ -300,7 +300,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
300300
b.ProcCtx().AsOf = asOf
301301
}
302302
b.ProcCtx().DbName = call.Database().Name()
303-
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false)
303+
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
304304
if err != nil {
305305
return nil, transform.SameTree, err
306306
}

sql/analyzer/triggers.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
204204
var parsedTrigger sql.Node
205205
sqlMode := sql.NewSqlModeFromString(trigger.SqlMode)
206206
b.SetParserOptions(sqlMode.ParserOptions())
207-
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false)
207+
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false)
208208
b.Reset()
209209
if err != nil {
210210
return nil, transform.SameTree, err
@@ -225,7 +225,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
225225
// first pass allows unresolved before we know whether trigger is relevant
226226
// TODO store destination table name with trigger, so we don't have to do parse twice
227227
b.TriggerCtx().Call = true
228-
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false)
228+
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false)
229229
b.TriggerCtx().Call = false
230230
b.Reset()
231231
if err != nil {

0 commit comments

Comments
 (0)