@@ -496,7 +496,7 @@ func (h *Handler) doQuery(
496496 } else if analyzer .FlagIsSet (qFlags , sql .QFlagMax1Row ) {
497497 r , err = resultForMax1RowIter (sqlCtx , schema , rowIter , resultFields , buf )
498498 } else if ri2 , ok := rowIter .(sql.RowIter2 ); ok && ri2 .IsRowIter2 (sqlCtx ) {
499- r , err = h .resultForDefaultIter2 (sqlCtx , ri2 , resultFields , callback , more )
499+ r , processedAtLeastOneBatch , err = h .resultForDefaultIter2 (sqlCtx , c , ri2 , resultFields , callback , more )
500500 } else {
501501 r , processedAtLeastOneBatch , err = h .resultForDefaultIter (sqlCtx , c , schema , rowIter , callback , resultFields , more , buf )
502502 }
@@ -770,30 +770,107 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
770770 return r , processedAtLeastOneBatch , nil
771771}
772772
773- func (h * Handler ) resultForDefaultIter2 (ctx * sql.Context , iter sql.RowIter2 , resultFields []* querypb.Field , callback func (* sqltypes.Result , bool ) error , more bool ) (* sqltypes.Result , error ) {
774- res := & sqltypes.Result {Fields : resultFields }
775- for {
776- if res .RowsAffected == rowsBatch {
777- if err := callback (res , more ); err != nil {
778- return nil , err
779- }
780- res = nil
781- }
782- row , err := iter .Next2 (ctx )
783- if err == io .EOF {
784- return res , nil
773+ func (h * Handler ) resultForDefaultIter2 (ctx * sql.Context , c * mysql.Conn , iter sql.RowIter2 , resultFields []* querypb.Field , callback func (* sqltypes.Result , bool ) error , more bool ) (* sqltypes.Result , bool , error ) {
774+ defer trace .StartRegion (ctx , "Handler.resultForDefaultIter" ).End ()
775+
776+ eg , ctx := ctx .NewErrgroup ()
777+ pan2err := func (err * error ) {
778+ if recoveredPanic := recover (); recoveredPanic != nil {
779+ stack := debug .Stack ()
780+ wrappedErr := fmt .Errorf ("handler caught panic: %v\n %s" , recoveredPanic , stack )
781+ * err = goerrors .Join (* err , wrappedErr )
785782 }
786- if err != nil {
787- return nil , err
783+ }
784+
785+ // TODO: poll for closed connections should obviously also run even if
786+ // we're doing something with an OK result or a single row result, etc.
787+ // This should be in the caller.
788+ pollCtx , cancelF := ctx .NewSubContext ()
789+ eg .Go (func () (err error ) {
790+ defer pan2err (& err )
791+ return h .pollForClosedConnection (pollCtx , c )
792+ })
793+
794+ // Default waitTime is one minute if there is no timeout configured, in which case
795+ // it will loop to iterate again unless the socket died by the OS timeout or other problems.
796+ // If there is a timeout, it will be enforced to ensure that Vitess has a chance to
797+ // call Handler.CloseConnection()
798+ waitTime := 1 * time .Minute
799+ if h .readTimeout > 0 {
800+ waitTime = h .readTimeout
801+ }
802+ timer := time .NewTimer (waitTime )
803+ defer timer .Stop ()
804+
805+ wg := sync.WaitGroup {}
806+ wg .Add (1 )
807+
808+ var res * sqltypes.Result
809+ var processedAtLeastOneBatch bool
810+ eg .Go (func () (err error ) {
811+ defer pan2err (& err )
812+ defer cancelF ()
813+ defer wg .Done ()
814+ for {
815+ if res == nil {
816+ res = & sqltypes.Result {Fields : resultFields }
817+ }
818+ if res .RowsAffected == rowsBatch {
819+ if err := callback (res , more ); err != nil {
820+ return err
821+ }
822+ res = nil
823+ processedAtLeastOneBatch = true
824+ continue
825+ }
826+
827+ select {
828+ case <- ctx .Done ():
829+ return context .Cause (ctx )
830+ case <- timer .C :
831+ // TODO: timer should probably go in its own thread, as rowChan is blocking
832+ if h .readTimeout != 0 {
833+ // Cancel and return so Vitess can call the CloseConnection callback
834+ ctx .GetLogger ().Tracef ("connection timeout" )
835+ return ErrRowTimeout .New ()
836+ }
837+ default :
838+ row , err := iter .Next2 (ctx )
839+ if err == io .EOF {
840+ return nil
841+ }
842+ if err != nil {
843+ return err
844+ }
845+ outRow := make ([]sqltypes.Value , len (row ))
846+ for i := range row {
847+ outRow [i ] = sqltypes .MakeTrusted (row [i ].Typ , row [i ].Val )
848+ }
849+ res .Rows = append (res .Rows , outRow )
850+ res .RowsAffected ++
851+ }
852+ timer .Reset (waitTime )
788853 }
854+ })
789855
790- outRow := make ([]sqltypes.Value , len (res .Rows ))
791- for i := range row {
792- outRow [i ] = sqltypes .MakeTrusted (row [i ].Typ , row [i ].Val )
856+ // Close() kills this PID in the process list,
857+ // wait until all rows have be sent over the wire
858+ eg .Go (func () (err error ) {
859+ defer pan2err (& err )
860+ wg .Wait ()
861+ return iter .Close (ctx )
862+ })
863+
864+ err := eg .Wait ()
865+ if err != nil {
866+ ctx .GetLogger ().WithError (err ).Warn ("error running query" )
867+ if verboseErrorLogging {
868+ fmt .Printf ("Err: %+v" , err )
793869 }
794- res .Rows = append (res .Rows , outRow )
795- res .RowsAffected ++
870+ return nil , false , err
796871 }
872+
873+ return res , processedAtLeastOneBatch , nil
797874}
798875
799876// See https://dev.mysql.com/doc/internals/en/status-flags.html
0 commit comments