@@ -17,6 +17,7 @@ package server
1717import (
1818 "context"
1919 "encoding/base64"
20+ goerrors "errors"
2021 "fmt"
2122 "io"
2223 "net"
@@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
609610
610611// resultForDefaultIter reads batches of rows from the iterator
611612// and writes results into the callback function.
612- func (h * Handler ) resultForDefaultIter (ctx * sql.Context , c * mysql.Conn , schema sql.Schema , iter sql.RowIter , callback func (* sqltypes.Result , bool ) error , resultFields []* querypb.Field , more bool , buf * sql.ByteBuffer ) (r * sqltypes.Result , processedAtLeastOneBatch bool , returnErr error ) {
613+ func (h * Handler ) resultForDefaultIter (ctx * sql.Context , c * mysql.Conn , schema sql.Schema , iter sql.RowIter , callback func (* sqltypes.Result , bool ) error , resultFields []* querypb.Field , more bool , buf * sql.ByteBuffer ) (* sqltypes.Result , bool , error ) {
613614 defer trace .StartRegion (ctx , "Handler.resultForDefaultIter" ).End ()
614615
615616 eg , ctx := ctx .NewErrgroup ()
616-
617- pan2err := func () {
617+ pan2err := func (err * error ) {
618618 if recoveredPanic := recover (); recoveredPanic != nil {
619- returnErr = fmt .Errorf ("handler caught panic: %v" , recoveredPanic )
619+ * err = goerrors . Join ( * err , fmt .Errorf ("handler caught panic: %v" , recoveredPanic ) )
620620 }
621621 }
622-
623622 wg := sync.WaitGroup {}
624623 wg .Add (2 )
625624
625+ var r * sqltypes.Result
626+ var processedAtLeastOneBatch bool
627+
626628 // Read rows off the row iterator and send them to the row channel.
627629 iter , projs := GetDeferredProjections (iter )
628630 var rowChan = make (chan sql.Row , 512 )
629- eg .Go (func () error {
630- defer pan2err ()
631+ eg .Go (func () ( err error ) {
632+ defer pan2err (& err )
631633 defer wg .Done ()
632634 defer close (rowChan )
633635 for {
634636 select {
635637 case <- ctx .Done ():
636- return nil
638+ return context . Cause ( ctx )
637639 default :
638640 row , err := iter .Next (ctx )
639641 if err == io .EOF {
@@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
651653 }
652654 })
653655
656+ // TODO: poll for closed connections should obviously also run even if
657+ // we're doing something with an OK result or a single row result, etc.
658+ // This should be in the caller.
654659 pollCtx , cancelF := ctx .NewSubContext ()
655- eg .Go (func () error {
656- defer pan2err ()
660+ eg .Go (func () ( err error ) {
661+ defer pan2err (& err )
657662 return h .pollForClosedConnection (pollCtx , c )
658663 })
659664
@@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
676681
677682 // Reads rows from the channel, converts them to wire format,
678683 // and calls |callback| to give them to vitess.
679- eg .Go (func () error {
680- defer pan2err ()
684+ eg .Go (func () ( err error ) {
685+ defer pan2err (& err )
681686 defer cancelF ()
682687 defer wg .Done ()
683688 for {
@@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
695700
696701 select {
697702 case <- ctx .Done ():
698- return nil
703+ return context . Cause ( ctx )
699704 case row , ok := <- rowChan :
700705 if ! ok {
701706 return nil
@@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
716721 ctx .GetLogger ().Tracef ("spooling result row %s" , outputRow )
717722 r .Rows = append (r .Rows , outputRow )
718723 r .RowsAffected ++
724+ if ! timer .Stop () {
725+ <- timer .C
726+ }
719727 case <- timer .C :
720728 // TODO: timer should probably go in its own thread, as rowChan is blocking
721729 if h .readTimeout != 0 {
@@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
724732 return ErrRowTimeout .New ()
725733 }
726734 }
727- if ! timer .Stop () {
728- <- timer .C
729- }
730735 timer .Reset (waitTime )
731736 }
732737 })
733738
734739 // Close() kills this PID in the process list,
735740 // wait until all rows have be sent over the wire
736- eg .Go (func () error {
737- defer pan2err ()
741+ eg .Go (func () ( err error ) {
742+ defer pan2err (& err )
738743 wg .Wait ()
739744 return iter .Close (ctx )
740745 })
@@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
745750 if verboseErrorLogging {
746751 fmt .Printf ("Err: %+v" , err )
747752 }
748- returnErr = err
753+ return nil , false , err
749754 }
750- return
755+ return r , processedAtLeastOneBatch , nil
751756}
752757
753758// See https://dev.mysql.com/doc/internals/en/status-flags.html
0 commit comments