@@ -17,6 +17,7 @@ package server
17
17
import (
18
18
"context"
19
19
"encoding/base64"
20
+ goerrors "errors"
20
21
"fmt"
21
22
"io"
22
23
"net"
@@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
609
610
610
611
// resultForDefaultIter reads batches of rows from the iterator
611
612
// and writes results into the callback function.
612
- func (h * Handler ) resultForDefaultIter (ctx * sql.Context , c * mysql.Conn , schema sql.Schema , iter sql.RowIter , callback func (* sqltypes.Result , bool ) error , resultFields []* querypb.Field , more bool , buf * sql.ByteBuffer ) (r * sqltypes.Result , processedAtLeastOneBatch bool , returnErr error ) {
613
+ func (h * Handler ) resultForDefaultIter (ctx * sql.Context , c * mysql.Conn , schema sql.Schema , iter sql.RowIter , callback func (* sqltypes.Result , bool ) error , resultFields []* querypb.Field , more bool , buf * sql.ByteBuffer ) (* sqltypes.Result , bool , error ) {
613
614
defer trace .StartRegion (ctx , "Handler.resultForDefaultIter" ).End ()
614
615
615
616
eg , ctx := ctx .NewErrgroup ()
616
-
617
- pan2err := func () {
617
+ pan2err := func (err * error ) {
618
618
if recoveredPanic := recover (); recoveredPanic != nil {
619
- returnErr = fmt .Errorf ("handler caught panic: %v" , recoveredPanic )
619
+ * err = goerrors . Join ( * err , fmt .Errorf ("handler caught panic: %v" , recoveredPanic ) )
620
620
}
621
621
}
622
-
623
622
wg := sync.WaitGroup {}
624
623
wg .Add (2 )
625
624
625
+ var r * sqltypes.Result
626
+ var processedAtLeastOneBatch bool
627
+
626
628
// Read rows off the row iterator and send them to the row channel.
627
629
iter , projs := GetDeferredProjections (iter )
628
630
var rowChan = make (chan sql.Row , 512 )
629
- eg .Go (func () error {
630
- defer pan2err ()
631
+ eg .Go (func () ( err error ) {
632
+ defer pan2err (& err )
631
633
defer wg .Done ()
632
634
defer close (rowChan )
633
635
for {
634
636
select {
635
637
case <- ctx .Done ():
636
- return nil
638
+ return context . Cause ( ctx )
637
639
default :
638
640
row , err := iter .Next (ctx )
639
641
if err == io .EOF {
@@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
651
653
}
652
654
})
653
655
656
+ // TODO: poll for closed connections should obviously also run even if
657
+ // we're doing something with an OK result or a single row result, etc.
658
+ // This should be in the caller.
654
659
pollCtx , cancelF := ctx .NewSubContext ()
655
- eg .Go (func () error {
656
- defer pan2err ()
660
+ eg .Go (func () ( err error ) {
661
+ defer pan2err (& err )
657
662
return h .pollForClosedConnection (pollCtx , c )
658
663
})
659
664
@@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
676
681
677
682
// Reads rows from the channel, converts them to wire format,
678
683
// and calls |callback| to give them to vitess.
679
- eg .Go (func () error {
680
- defer pan2err ()
684
+ eg .Go (func () ( err error ) {
685
+ defer pan2err (& err )
681
686
defer cancelF ()
682
687
defer wg .Done ()
683
688
for {
@@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
695
700
696
701
select {
697
702
case <- ctx .Done ():
698
- return nil
703
+ return context . Cause ( ctx )
699
704
case row , ok := <- rowChan :
700
705
if ! ok {
701
706
return nil
@@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
716
721
ctx .GetLogger ().Tracef ("spooling result row %s" , outputRow )
717
722
r .Rows = append (r .Rows , outputRow )
718
723
r .RowsAffected ++
724
+ if ! timer .Stop () {
725
+ <- timer .C
726
+ }
719
727
case <- timer .C :
720
728
// TODO: timer should probably go in its own thread, as rowChan is blocking
721
729
if h .readTimeout != 0 {
@@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
724
732
return ErrRowTimeout .New ()
725
733
}
726
734
}
727
- if ! timer .Stop () {
728
- <- timer .C
729
- }
730
735
timer .Reset (waitTime )
731
736
}
732
737
})
733
738
734
739
// Close() kills this PID in the process list,
735
740
// wait until all rows have be sent over the wire
736
- eg .Go (func () error {
737
- defer pan2err ()
741
+ eg .Go (func () ( err error ) {
742
+ defer pan2err (& err )
738
743
wg .Wait ()
739
744
return iter .Close (ctx )
740
745
})
@@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
745
750
if verboseErrorLogging {
746
751
fmt .Printf ("Err: %+v" , err )
747
752
}
748
- returnErr = err
753
+ return nil , false , err
749
754
}
750
- return
755
+ return r , processedAtLeastOneBatch , nil
751
756
}
752
757
753
758
// See https://dev.mysql.com/doc/internals/en/status-flags.html
0 commit comments