Skip to content

Commit 8fb4a87

Browse files
committed
server/handler.go: Improve some edge cases and error handling in resultForDefaultIter.
1 parent f73a318 commit 8fb4a87

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

server/handler.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package server
1717
import (
1818
"context"
1919
"encoding/base64"
20+
goerrors "errors"
2021
"fmt"
2122
"io"
2223
"net"
@@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
609610

610611
// resultForDefaultIter reads batches of rows from the iterator
611612
// and writes results into the callback function.
612-
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
613+
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (*sqltypes.Result, bool, error) {
613614
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()
614615

615616
eg, ctx := ctx.NewErrgroup()
616-
617-
pan2err := func() {
617+
pan2err := func(err *error) {
618618
if recoveredPanic := recover(); recoveredPanic != nil {
619-
returnErr = fmt.Errorf("handler caught panic: %v", recoveredPanic)
619+
*err = goerrors.Join(*err, fmt.Errorf("handler caught panic: %v", recoveredPanic))
620620
}
621621
}
622-
623622
wg := sync.WaitGroup{}
624623
wg.Add(2)
625624

625+
var r *sqltypes.Result
626+
var processedAtLeastOneBatch bool
627+
626628
// Read rows off the row iterator and send them to the row channel.
627629
iter, projs := GetDeferredProjections(iter)
628630
var rowChan = make(chan sql.Row, 512)
629-
eg.Go(func() error {
630-
defer pan2err()
631+
eg.Go(func() (err error) {
632+
defer pan2err(&err)
631633
defer wg.Done()
632634
defer close(rowChan)
633635
for {
634636
select {
635637
case <-ctx.Done():
636-
return nil
638+
return context.Cause(ctx)
637639
default:
638640
row, err := iter.Next(ctx)
639641
if err == io.EOF {
@@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
651653
}
652654
})
653655

656+
// TODO: poll for closed connections should obviously also run even if
657+
// we're doing something with an OK result or a single row result, etc.
658+
// This should be in the caller.
654659
pollCtx, cancelF := ctx.NewSubContext()
655-
eg.Go(func() error {
656-
defer pan2err()
660+
eg.Go(func() (err error) {
661+
defer pan2err(&err)
657662
return h.pollForClosedConnection(pollCtx, c)
658663
})
659664

@@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
676681

677682
// Reads rows from the channel, converts them to wire format,
678683
// and calls |callback| to give them to vitess.
679-
eg.Go(func() error {
680-
defer pan2err()
684+
eg.Go(func() (err error) {
685+
defer pan2err(&err)
681686
defer cancelF()
682687
defer wg.Done()
683688
for {
@@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
695700

696701
select {
697702
case <-ctx.Done():
698-
return nil
703+
return context.Cause(ctx)
699704
case row, ok := <-rowChan:
700705
if !ok {
701706
return nil
@@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
716721
ctx.GetLogger().Tracef("spooling result row %s", outputRow)
717722
r.Rows = append(r.Rows, outputRow)
718723
r.RowsAffected++
724+
if !timer.Stop() {
725+
<-timer.C
726+
}
719727
case <-timer.C:
720728
// TODO: timer should probably go in its own thread, as rowChan is blocking
721729
if h.readTimeout != 0 {
@@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
724732
return ErrRowTimeout.New()
725733
}
726734
}
727-
if !timer.Stop() {
728-
<-timer.C
729-
}
730735
timer.Reset(waitTime)
731736
}
732737
})
733738

734739
// Close() kills this PID in the process list,
735740
// wait until all rows have be sent over the wire
736-
eg.Go(func() error {
737-
defer pan2err()
741+
eg.Go(func() (err error) {
742+
defer pan2err(&err)
738743
wg.Wait()
739744
return iter.Close(ctx)
740745
})
@@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
745750
if verboseErrorLogging {
746751
fmt.Printf("Err: %+v", err)
747752
}
748-
returnErr = err
753+
return nil, false, err
749754
}
750-
return
755+
return r, processedAtLeastOneBatch, nil
751756
}
752757

753758
// See https://dev.mysql.com/doc/internals/en/status-flags.html

0 commit comments

Comments
 (0)