Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 62 additions & 47 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,12 @@ func (h *Handler) doQuery(
// create result before goroutines to avoid |ctx| racing
resultFields := schemaToFields(sqlCtx, schema)
var r *sqltypes.Result
var buf *sql.ByteBuffer
var processedAtLeastOneBatch bool

buf := sql.ByteBufPool.Get().(*sql.ByteBuffer)
defer func() {
buf.Reset()
sql.ByteBufPool.Put(buf)
if buf != nil {
sql.ByteBufPool.Put(buf)
}
}()

// zero/single return schema use spooling shortcut
Expand All @@ -496,9 +496,9 @@ func (h *Handler) doQuery(
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
} else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.IsValueRowIter(sqlCtx) {
r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more)
r, buf, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, callback, more)
} else {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
r, buf, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more)
}
if err != nil {
return remainder, err
Expand Down Expand Up @@ -527,6 +527,8 @@ func (h *Handler) doQuery(
return remainder, nil
}

// TODO: the very last buffer needs to be released

return remainder, callback(r, more)
}

Expand Down Expand Up @@ -598,7 +600,8 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
row, err := iter.Next(ctx)
if err == io.EOF {
return &sqltypes.Result{Fields: resultFields}, nil
} else if err != nil {
}
if err != nil {
return nil, err
}

Expand All @@ -618,7 +621,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (*sqltypes.Result, bool, error) {
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool) (*sqltypes.Result, *sql.ByteBuffer, bool, error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()
Expand Down Expand Up @@ -650,17 +653,6 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
timer := time.NewTimer(waitTime)
defer timer.Stop()

// Wrap the callback to include a BytesBuffer.Reset() for non-cursor requests, to
// clean out rows that have already been spooled.
// A server-side cursor allows the caller to fetch results cached on the server-side,
// so if a cursor exists, we can't release the buffer memory yet.
if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 {
callback = func(r *sqltypes.Result, more bool) error {
defer buf.Reset()
return callback(r, more)
}
}

iter, projs := GetDeferredProjections(iter)

wg := sync.WaitGroup{}
Expand Down Expand Up @@ -694,8 +686,13 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
})

// Drain rows from rowChan, convert to wire format, and send to resChan
var resChan = make(chan *sqltypes.Result, 4)
type bufferedResult struct {
res *sqltypes.Result
buf *sql.ByteBuffer
}
var resChan = make(chan bufferedResult, 4)
var res *sqltypes.Result
var buf *sql.ByteBuffer
eg.Go(func() (err error) {
defer pan2err(&err)
defer wg.Done()
Expand All @@ -707,6 +704,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
Fields: resultFields,
Rows: make([][]sqltypes.Value, 0, rowsBatch),
}
buf = sql.ByteBufPool.Get().(*sql.ByteBuffer)
buf.Reset()
}

select {
Expand Down Expand Up @@ -746,8 +745,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
select {
case <-ctx.Done():
return context.Cause(ctx)
case resChan <- res:
case resChan <- bufferedResult{res: res, buf: buf}:
res = nil
buf = nil // TODO: not sure if this is necessary to prevent double Put()
}
}
}
Expand All @@ -756,7 +756,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
}
})

// Drain sqltypes.Result from resChan and call callback (send to client and potentially reset buffer)
// Drain sqltypes.Result from resChan and call callback (send to client and reset buffer)
var processedAtLeastOneBatch bool
eg.Go(func() (err error) {
defer pan2err(&err)
Expand All @@ -766,15 +766,21 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
select {
case <-ctx.Done():
return context.Cause(ctx)
case r, ok := <-resChan:
case bufRes, ok := <-resChan:
if !ok {
return nil
}
processedAtLeastOneBatch = true
err = callback(r, more)
err = callback(bufRes.res, more)
if err != nil {
return err
}
processedAtLeastOneBatch = true
// A server-side cursor allows the caller to fetch results cached on the server-side,
// so if a cursor exists, we can't release the buffer memory yet.
// TODO: In the case of a cursor, we are leaking memory
if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 {
sql.ByteBufPool.Put(bufRes.buf)
}
}
}
})
Expand All @@ -793,12 +799,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
if verboseErrorLogging {
fmt.Printf("Err: %+v", err)
}
return nil, false, err
return nil, nil, false, err
}
return res, processedAtLeastOneBatch, nil
return res, buf, processedAtLeastOneBatch, nil
}

func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.ValueRowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.ValueRowIter, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, *sql.ByteBuffer, bool, error) {
defer trace.StartRegion(ctx, "Handler.resultForValueRowIter").End()

eg, ctx := ctx.NewErrgroup()
Expand Down Expand Up @@ -829,17 +835,6 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
timer := time.NewTimer(waitTime)
defer timer.Stop()

// Wrap the callback to include a BytesBuffer.Reset() for non-cursor requests, to
// clean out rows that have already been spooled.
// A server-side cursor allows the caller to fetch results cached on the server-side,
// so if a cursor exists, we can't release the buffer memory yet.
if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 {
callback = func(r *sqltypes.Result, more bool) error {
defer buf.Reset()
return callback(r, more)
}
}

wg := sync.WaitGroup{}
wg.Add(3)

Expand Down Expand Up @@ -871,8 +866,13 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
})

// Drain rows from rowChan, convert to wire format, and send to resChan
var resChan = make(chan *sqltypes.Result, 4)
type bufferedResult struct {
res *sqltypes.Result
buf *sql.ByteBuffer
}
var resChan = make(chan bufferedResult, 4)
var res *sqltypes.Result
var buf *sql.ByteBuffer
eg.Go(func() (err error) {
defer pan2err(&err)
defer close(resChan)
Expand All @@ -884,6 +884,8 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
Fields: resultFields,
Rows: make([][]sqltypes.Value, rowsBatch),
}
buf = sql.ByteBufPool.Get().(*sql.ByteBuffer)
buf.Reset()
}

select {
Expand Down Expand Up @@ -915,8 +917,9 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
select {
case <-ctx.Done():
return context.Cause(ctx)
case resChan <- res:
case resChan <- bufferedResult{res: res, buf: buf}:
res = nil
buf = nil
}
}
}
Expand All @@ -935,15 +938,21 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
select {
case <-ctx.Done():
return context.Cause(ctx)
case r, ok := <-resChan:
case bufRes, ok := <-resChan:
if !ok {
return nil
}
processedAtLeastOneBatch = true
err = callback(r, more)
err = callback(bufRes.res, more)
if err != nil {
return err
}
processedAtLeastOneBatch = true
// A server-side cursor allows the caller to fetch results cached on the server-side,
// so if a cursor exists, we can't release the buffer memory yet.
// TODO: In the case of a cursor, we are leaking memory
if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 {
sql.ByteBufPool.Put(bufRes.buf)
}
}
}
})
Expand All @@ -962,11 +971,13 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
if verboseErrorLogging {
fmt.Printf("Err: %+v", err)
}
return nil, false, err
return nil, nil, false, err
}

res.Rows = res.Rows[:res.RowsAffected]
return res, processedAtLeastOneBatch, err
if res != nil {
res.Rows = res.Rows[:res.RowsAffected]
}
return res, buf, processedAtLeastOneBatch, err
}

// See https://dev.mysql.com/doc/internals/en/status-flags.html
Expand Down Expand Up @@ -1194,6 +1205,10 @@ func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interf
if buf == nil {
return typ.SQL(ctx, nil, val)
}
// TODO: possible to predict max amount of space needed in backing array.
// Only number types are written to byte buffer due to strconv.Append...
// String types already create a new []byte, so it's better to not copy to backing array.

ret, err := typ.SQL(ctx, buf.Get(), val)
buf.Grow(ret.Len()) // TODO: shouldn't we check capacity beforehand?
return ret, err
Expand Down
8 changes: 4 additions & 4 deletions sql/byte_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ func NewByteBuffer(initCap int) *ByteBuffer {
// they expect to be protected.
func (b *ByteBuffer) Grow(n int) {
newI := b.i
if b.i+n <= len(b.buf) {
if b.i+n < cap(b.buf) {
// Increment |b.i| if no alloc
newI += n
}
if b.i+n >= len(b.buf) {
} else {
// No more space, double.
// An external allocation doubled the cap using the size of
// the override object, which if used could lead to overall
Expand All @@ -59,7 +58,8 @@ func (b *ByteBuffer) Grow(n int) {
// here because the runtime only doubles based on slice
// length.
func (b *ByteBuffer) Double() {
buf := make([]byte, len(b.buf)*2)
// TODO: This wastes memory. The first half of b.buf won't be referenced by anything.
buf := make([]byte, cap(b.buf)*2)
copy(buf, b.buf)
b.buf = buf
}
Expand Down
23 changes: 23 additions & 0 deletions sql/byte_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package sql

import (
"fmt"
"strconv"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -70,3 +72,24 @@ func TestGrowByteBuffer(t *testing.T) {
require.Equal(t, 40, len(b.buf))
require.Equal(t, 0, b.i)
}

// TestByteBufferDoubling exercises the interaction between ByteBuffer.Get,
// strconv.AppendInt, and ByteBuffer.Grow across a capacity doubling of the
// backing array (initial capacity 5 is too small for both appended numbers).
//
// NOTE(review): this test only prints internal state (buf contents, cap, and
// the write index i) and makes no assertions — it documents behavior rather
// than verifying it. TODO: replace the Printf calls with require.Equal checks
// once the intended Grow/Double semantics are settled.
func TestByteBufferDoubling(t *testing.T) {
	bb := NewByteBuffer(5)
	// Initial state: presumably i == 0 and cap(bb.buf) >= 5 — confirm against NewByteBuffer.
	fmt.Printf("bb.buf: %v, cap: %d\n", bb.buf, cap(bb.buf))
	fmt.Printf("bb.i: %v\n", bb.i)

	// First write: "12345" is 5 bytes; appending at offset 0 may already
	// exceed the initial capacity and force an allocation inside AppendInt.
	i0 := bb.Get()
	i0 = strconv.AppendInt(i0, 12345, 10)
	bb.Grow(len(i0))
	fmt.Printf("i0: %v, cap: %d\n", i0, cap(i0))
	fmt.Printf("bb.buf: %v, cap: %d\n", bb.buf, cap(bb.buf))
	fmt.Printf("bb.i: %v\n", bb.i)

	// Second write: "678901" is 6 bytes, which should push Grow down its
	// doubling path. Printing i0 again shows whether the earlier slice still
	// aliases bb.buf after the doubling or was detached by a reallocation —
	// TODO(review): assert the aliasing behavior explicitly.
	i5 := bb.Get()
	i5 = strconv.AppendInt(i5, 678901, 10)
	bb.Grow(len(i5))
	fmt.Printf("i0: %v, cap: %d\n", i0, cap(i0))
	fmt.Printf("i5: %v, cap: %d\n", i5, cap(i5))
	fmt.Printf("bb.buf: %v, cap: %d\n", bb.buf, cap(bb.buf))
	fmt.Printf("bb.i: %v\n", bb.i)
}
Loading