@@ -40,7 +40,9 @@ import (
4040 "github.com/dolthub/go-mysql-server/internal/sockstate"
4141 "github.com/dolthub/go-mysql-server/sql"
4242 "github.com/dolthub/go-mysql-server/sql/analyzer"
43+ "github.com/dolthub/go-mysql-server/sql/iters"
4344 "github.com/dolthub/go-mysql-server/sql/plan"
45+ "github.com/dolthub/go-mysql-server/sql/rowexec"
4446 "github.com/dolthub/go-mysql-server/sql/types"
4547)
4648
@@ -218,7 +220,7 @@ func (h *Handler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query s
218220func (h * Handler ) ComStmtExecute (ctx context.Context , c * mysql.Conn , prepare * mysql.PrepareData , callback func (* sqltypes.Result ) error ) error {
219221 _ , err := h .errorWrappedDoQuery (ctx , c , prepare .PrepareStmt , nil , MultiStmtModeOff , prepare .BindVars , func (res * sqltypes.Result , more bool ) error {
220222 return callback (res )
221- }, nil )
223+ }, & sql. QueryFlags {} )
222224 return err
223225}
224226
@@ -295,7 +297,7 @@ func (h *Handler) ComMultiQuery(
295297 query string ,
296298 callback mysql.ResultSpoolFn ,
297299) (string , error ) {
298- return h .errorWrappedDoQuery (ctx , c , query , nil , MultiStmtModeOn , nil , callback , nil )
300+ return h .errorWrappedDoQuery (ctx , c , query , nil , MultiStmtModeOn , nil , callback , & sql. QueryFlags {} )
299301}
300302
301303// ComQuery executes a SQL query on the SQLe engine.
@@ -305,7 +307,7 @@ func (h *Handler) ComQuery(
305307 query string ,
306308 callback mysql.ResultSpoolFn ,
307309) error {
308- _ , err := h .errorWrappedDoQuery (ctx , c , query , nil , MultiStmtModeOff , nil , callback , nil )
310+ _ , err := h .errorWrappedDoQuery (ctx , c , query , nil , MultiStmtModeOff , nil , callback , & sql. QueryFlags {} )
309311 return err
310312}
311313
@@ -317,7 +319,7 @@ func (h *Handler) ComParsedQuery(
317319 parsed sqlparser.Statement ,
318320 callback mysql.ResultSpoolFn ,
319321) error {
320- _ , err := h .errorWrappedDoQuery (ctx , c , query , parsed , MultiStmtModeOff , nil , callback , nil )
322+ _ , err := h .errorWrappedDoQuery (ctx , c , query , parsed , MultiStmtModeOff , nil , callback , & sql. QueryFlags {} )
321323 return err
322324}
323325
@@ -424,6 +426,7 @@ func (h *Handler) doQuery(
424426 }
425427 }()
426428
429+ qFlags .Set (sql .QFlagDeferProjections )
427430 schema , rowIter , qFlags , err := queryExec (sqlCtx , query , parsed , analyzedPlan , bindings , qFlags )
428431 if err != nil {
429432 sqlCtx .GetLogger ().WithError (err ).Warn ("error running query" )
@@ -511,6 +514,37 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter, resultFields []*quer
511514 return & sqltypes.Result {Fields : resultFields }, nil
512515}
513516
517+ // GetDeferredProjections looks for a top-level deferred projection, retrieves its projections, and removes it from the
518+ // iterator tree.
519+ func GetDeferredProjections (iter sql.RowIter ) (sql.RowIter , []sql.Expression ) {
520+ switch i := iter .(type ) {
521+ case * rowexec.ExprCloserIter :
522+ _ , projs := GetDeferredProjections (i .GetIter ())
523+ return i , projs
524+ case * plan.TrackedRowIter :
525+ _ , projs := GetDeferredProjections (i .GetIter ())
526+ return i , projs
527+ case * rowexec.TransactionCommittingIter :
528+ newChild , projs := GetDeferredProjections (i .GetIter ())
529+ if projs != nil {
530+ i .WithChildIter (newChild )
531+ }
532+ return i , projs
533+ case * iters.LimitIter :
534+ newChild , projs := GetDeferredProjections (i .ChildIter )
535+ if projs != nil {
536+ i .ChildIter = newChild
537+ }
538+ return i , projs
539+ case * rowexec.ProjectIter :
540+ if i .CanDefer () {
541+ return i .GetChildIter (), i .GetProjections ()
542+ }
543+ return i , nil
544+ }
545+ return iter , nil
546+ }
547+
514548// resultForMax1RowIter ensures that an empty iterator returns at most one row
515549func resultForMax1RowIter (ctx * sql.Context , schema sql.Schema , iter sql.RowIter , resultFields []* querypb.Field ) (* sqltypes.Result , error ) {
516550 defer trace .StartRegion (ctx , "Handler.resultForMax1RowIter" ).End ()
@@ -527,8 +561,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
527561 if err := iter .Close (ctx ); err != nil {
528562 return nil , err
529563 }
530-
531- outputRow , err := rowToSQL (ctx , schema , row )
564+ outputRow , err := RowToSQL (ctx , schema , row , nil )
532565 if err != nil {
533566 return nil , err
534567 }
@@ -558,16 +591,11 @@ func (h *Handler) resultForDefaultIter(
558591 }
559592 }
560593
561- pollCtx , cancelF := ctx .NewSubContext ()
562- eg .Go (func () error {
563- defer pan2err ()
564- return h .pollForClosedConnection (pollCtx , c )
565- })
566-
567594 wg := sync.WaitGroup {}
568595 wg .Add (2 )
569596
570597 // Read rows off the row iterator and send them to the row channel.
598+ iter , projs := GetDeferredProjections (iter )
571599 var rowChan = make (chan sql.Row , 512 )
572600 eg .Go (func () error {
573601 defer pan2err ()
@@ -594,6 +622,12 @@ func (h *Handler) resultForDefaultIter(
594622 }
595623 })
596624
625+ pollCtx , cancelF := ctx .NewSubContext ()
626+ eg .Go (func () error {
627+ defer pan2err ()
628+ return h .pollForClosedConnection (pollCtx , c )
629+ })
630+
597631 // Default waitTime is one minute if there is no timeout configured, in which case
598632 // it will loop to iterate again unless the socket died by the OS timeout or other problems.
599633 // If there is a timeout, it will be enforced to ensure that Vitess has a chance to
@@ -639,7 +673,7 @@ func (h *Handler) resultForDefaultIter(
639673 continue
640674 }
641675
642- outputRow , err := rowToSQL (ctx , schema , row )
676+ outputRow , err := RowToSQL (ctx , schema , row , projs )
643677 if err != nil {
644678 return err
645679 }
@@ -648,6 +682,7 @@ func (h *Handler) resultForDefaultIter(
648682 r .Rows = append (r .Rows , outputRow )
649683 r .RowsAffected ++
650684 case <- timer .C :
685+ // TODO: timer should probably go in its own thread, as rowChan is blocking
651686 if h .readTimeout != 0 {
652687 // Cancel and return so Vitess can call the CloseConnection callback
653688 ctx .GetLogger ().Tracef ("connection timeout" )
@@ -901,25 +936,43 @@ func updateMaxUsedConnectionsStatusVariable() {
901936 }()
902937}
903938
904- func rowToSQL (ctx * sql.Context , s sql.Schema , row sql.Row ) ([]sqltypes.Value , error ) {
905- o := make ([]sqltypes.Value , len (row ))
939+ func RowToSQL (ctx * sql.Context , sch sql.Schema , row sql.Row , projs []sql.Expression ) ([]sqltypes.Value , error ) {
906940 // need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
907- if len (s ) == 0 {
908- return o , nil
941+ if len (sch ) == 0 {
942+ return []sqltypes. Value {} , nil
909943 }
910- var err error
911- for i , v := range row {
912- if v == nil {
913- o [i ] = sqltypes .NULL
944+
945+ outVals := make ([]sqltypes.Value , len (sch ))
946+ if len (projs ) == 0 {
947+ for i , col := range sch {
948+ if row [i ] == nil {
949+ outVals [i ] = sqltypes .NULL
950+ continue
951+ }
952+ var err error
953+ outVals [i ], err = col .Type .SQL (ctx , nil , row [i ])
954+ if err != nil {
955+ return nil , err
956+ }
957+ }
958+ return outVals , nil
959+ }
960+
961+ for i , col := range sch {
962+ field , err := projs [i ].Eval (ctx , row )
963+ if err != nil {
964+ return nil , err
965+ }
966+ if field == nil {
967+ outVals [i ] = sqltypes .NULL
914968 continue
915969 }
916- o [i ], err = s [ i ] .Type .SQL (ctx , nil , v )
970+ outVals [i ], err = col .Type .SQL (ctx , nil , field )
917971 if err != nil {
918972 return nil , err
919973 }
920974 }
921-
922- return o , nil
975+ return outVals , nil
923976}
924977
925978func schemaToFields (ctx * sql.Context , s sql.Schema ) []* querypb.Field {
0 commit comments