Merged
Changes from 7 commits
3 changes: 2 additions & 1 deletion enginetest/enginetests.go
@@ -5743,6 +5743,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
require.NoError(t, err)
expectedRowSet := script.Results[queryIdx]
expectedRowIdx := 0
buf := sql.NewByteBuffer(1000)
var engineRow sql.Row
for engineRow, err = engineIter.Next(ctx); err == nil; engineRow, err = engineIter.Next(ctx) {
if !assert.True(t, r.Next()) {
@@ -5760,7 +5761,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
break
}
expectedEngineRow := make([]*string, len(engineRow))
row, err := server.RowToSQL(ctx, sch, engineRow, nil)
row, err := server.RowToSQL(ctx, sch, engineRow, nil, buf)
if !assert.NoError(t, err) {
break
}
47 changes: 30 additions & 17 deletions server/handler.go
@@ -442,15 +442,21 @@ func (h *Handler) doQuery(
var r *sqltypes.Result
var processedAtLeastOneBatch bool

buf := sql.ByteBufPool.Get().(*sql.ByteBuffer)
defer func() {
buf.Reset()
sql.ByteBufPool.Put(buf)
}()

// zero-row and single-row return schemas use a spooling shortcut
if types.IsOkResultSchema(schema) {
r, err = resultForOkIter(sqlCtx, rowIter)
} else if schema == nil {
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
} else {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more)
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
}
if err != nil {
return remainder, err
@@ -542,7 +548,7 @@ func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
}

// resultForMax1RowIter ensures that an empty iterator returns at most one row
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) {
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer) (*sqltypes.Result, error) {
defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End()
row, err := iter.Next(ctx)
if err == io.EOF {
@@ -557,7 +563,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
if err := iter.Close(ctx); err != nil {
return nil, err
}
outputRow, err := RowToSQL(ctx, schema, row, nil)
outputRow, err := RowToSQL(ctx, schema, row, nil, buf)
if err != nil {
return nil, err
}
@@ -569,14 +575,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *Handler) resultForDefaultIter(
ctx *sql.Context,
c *mysql.Conn,
schema sql.Schema,
iter sql.RowIter,
callback func(*sqltypes.Result, bool) error,
resultFields []*querypb.Field,
more bool) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()
@@ -669,7 +668,7 @@ func (h *Handler) resultForDefaultIter(
continue
}

outputRow, err := RowToSQL(ctx, schema, row, projs)
outputRow, err := RowToSQL(ctx, schema, row, projs, buf)
if err != nil {
return err
}
@@ -932,21 +931,35 @@ func updateMaxUsedConnectionsStatusVariable() {
}()
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) {
func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interface{}) (sqltypes.Value, error) {
	if buf == nil {
		// No pooled buffer; let typ.SQL allocate its own destination.
		return typ.SQL(ctx, nil, val)
	}
	spare := buf.Spare()
	ret, err := typ.SQL(ctx, buf.Get(), val)
	if ret.Len() > spare {
		// The value outgrew the spare capacity, so append reallocated and
		// ret no longer aliases buf; double the backing array for next time.
		buf.Double()
	} else {
		// The value was written in place; reserve those bytes.
		buf.Advance(ret.Len())
	}
	return ret, err
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, buf *sql.ByteBuffer) ([]sqltypes.Value, error) {
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
if len(sch) == 0 {
return []sqltypes.Value{}, nil
}

outVals := make([]sqltypes.Value, len(sch))
var err error
if len(projs) == 0 {
for i, col := range sch {
if row[i] == nil {
outVals[i] = sqltypes.NULL
continue
}
var err error
outVals[i], err = col.Type.SQL(ctx, nil, row[i])
outVals[i], err = toSqlHelper(ctx, col.Type, buf, row[i])
if err != nil {
return nil, err
}
Expand All @@ -963,7 +976,7 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
outVals[i] = sqltypes.NULL
continue
}
outVals[i], err = col.Type.SQL(ctx, nil, field)
outVals[i], err = toSqlHelper(ctx, col.Type, buf, field)
if err != nil {
return nil, err
}
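Taken together, the handler now checks a single ByteBuffer out of a pool per query and threads it through every row conversion. Below is a minimal sketch of that lifecycle, assuming this PR's sql.ByteBufPool and server.RowToSQL as shown above; serializeRows itself is a hypothetical wrapper, not a function in this change.

package example

import (
	"github.com/dolthub/go-mysql-server/server"
	"github.com/dolthub/go-mysql-server/sql"
)

// serializeRows (hypothetical) mirrors doQuery's buffer handling: one pooled
// buffer is shared by every serialized value in the result set, then reset
// and returned so later queries can reuse its backing array.
func serializeRows(ctx *sql.Context, sch sql.Schema, rows []sql.Row) error {
	buf := sql.ByteBufPool.Get().(*sql.ByteBuffer)
	defer func() {
		buf.Reset()
		sql.ByteBufPool.Put(buf)
	}()
	for _, row := range rows {
		// toSqlHelper appends each value into buf's spare capacity,
		// advancing the offset (or doubling the buffer) per value.
		if _, err := server.RowToSQL(ctx, sch, row, nil, buf); err != nil {
			return err
		}
	}
	return nil
}

Because the pool reuses buffers across queries, the deferred Reset is what keeps one query's offsets from leaking into the next.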
72 changes: 72 additions & 0 deletions sql/byte_buffer.go
@@ -0,0 +1,72 @@
package sql

import (
"sync"
)

var SingletonBuf = NewByteBuffer(16000)

var defaultByteBuffCap = 1000

var ByteBufPool = sync.Pool{
New: func() any {
// The Pool's New function should generally only return pointer
// types, since a pointer can be put into the return interface
// value without an allocation:
return NewByteBuffer(defaultByteBuffCap)
},
}

type ByteBuffer struct {
Member:
This idea has merit and is surely better than allocating a buffer for each request, but the way you're managing the memory is suboptimal. It's also good to use the same backing array for multiple values in a request.

In the main use of this object in the handler, you're getting a zero-length slice (which has some larger backing array) and then calling append on it byte by byte. This will grow the backing array in some cases, but it's not being done under your deliberate control. Rather, you're then calling Double if the backing array is low on space after the appends have already happened.

Basically: in these methods, you are referring to the len of the byte slice, when your concern is usually the cap. It's fine to let append happen byte by byte as long as it isn't doubling the backing array too often; that's the expensive bit.

I think this would probably work slightly better if you scrapped the explicit capacity management altogether and let the Go runtime manage it automatically for you. Either that, or always manage it explicitly yourself, i.e. before you serialize a value with all those append calls. But it's not clear to me that manual management is actually any better if you use the same strategy as the Go runtime does (double once we're full).

Contributor Author:
I agree with all of this, but there are two caveats that limit our ability to let the runtime handle this for us. (1) The runtime chooses doubling based on the cap of the slice, not the full backing array, so for the current setup the doubled array is usually actually smaller than the original backing array. (2) Doubled arrays are not reference-swapped; we need a handle to the new buffer to know when to grow.

I'm not aware of a way to override the default runtime growth behavior to ignore the slice cap and instead double based on the backing array cap. So ByteBuffer still does the doubling, and exposes a Grow(n int)-style interface to track when this should happen. We pay for two mallocs on doubling, because the first one is never big enough. Not calling Grow after allocating, or growing by less than the length actually appended, will stomp previously written memory. (A concrete walkthrough of this protocol follows this file's diff.)

	buf []byte
	i   int // offset of the first unreserved byte in buf
}

// NewByteBuffer returns a buffer backed by an initCap-byte array.
func NewByteBuffer(initCap int) *ByteBuffer {
	return &ByteBuffer{buf: make([]byte, initCap)}
}

// Bytes returns the entire backing array, including unreserved capacity.
func (b *ByteBuffer) Bytes() []byte {
	return b.buf
}

// GetFull reserves and returns the next i bytes, growing the backing array
// until the reservation fits.
func (b *ByteBuffer) GetFull(i int) []byte {
	start := b.i
	b.i = start + i
	for b.i > len(b.buf) {
		newBuf := make([]byte, len(b.buf)*2)
		copy(newBuf, b.buf)
		b.buf = newBuf
	}
	return b.buf[start:b.i]
}

// Double grows the backing array to twice its size, preserving contents.
func (b *ByteBuffer) Double() {
	newBuf := make([]byte, len(b.buf)*2)
	copy(newBuf, b.buf)
	b.buf = newBuf
}

// Advance reserves i bytes that the caller has already written after Get.
func (b *ByteBuffer) Advance(i int) {
	b.i += i
}

// Spare reports how many unreserved bytes remain in the backing array.
func (b *ByteBuffer) Spare() int {
	return len(b.buf) - b.i
}

// Get returns a zero-length slice at the current offset; callers append into
// it, then account for the written bytes via Advance (or Double on overflow).
func (b *ByteBuffer) Get() []byte {
	return b.buf[b.i:b.i]
}

// Reset rewinds the offset so the backing array can be reused; slices handed
// out earlier will be overwritten by subsequent writes.
func (b *ByteBuffer) Reset() {
	b.i = 0
}
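As referenced in the review thread above, here is a minimal walkthrough of the Get/Advance/Double protocol; the buffer size and string literals are arbitrary, chosen only to force one reallocation.

package main

import (
	"fmt"

	"github.com/dolthub/go-mysql-server/sql"
)

func main() {
	buf := sql.NewByteBuffer(8)

	// Get hands out a zero-length slice over the spare bytes; append then
	// writes into the shared backing array.
	a := buf.Get()
	a = append(a, "abc"...)
	buf.Advance(len(a)) // reserve those 3 bytes so the next Get starts after them

	// 10 bytes exceed the 5 spare bytes, so append reallocates: b now owns
	// a fresh array and no longer aliases buf.
	spare := buf.Spare()
	b := buf.Get()
	b = append(b, "defghijklm"...)
	if len(b) > spare {
		buf.Double() // grow the backing array so later values fit in place
	} else {
		buf.Advance(len(b))
	}

	fmt.Println(string(a), string(b)) // abc defghijklm
}

Skipping the Advance after a successful in-place write is exactly the stomping hazard the author's reply describes: the next Get would return the same bytes and overwrite a.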
2 changes: 1 addition & 1 deletion sql/types/bit.go
@@ -214,7 +214,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 {
data[i], data[j] = data[j], data[i]
}
val := AppendAndSliceBytes(dest, data)
val := data

return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
}
2 changes: 1 addition & 1 deletion sql/types/datetime.go
@@ -410,7 +410,7 @@ func (t datetimeType) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes.
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
}

valBytes := AppendAndSliceBytes(dest, val)
valBytes := val

return sqltypes.MakeTrusted(typ, valBytes), nil
}
2 changes: 1 addition & 1 deletion sql/types/enum.go
@@ -257,7 +257,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
}
val := AppendAndSliceBytes(dest, encodedBytes)
val := encodedBytes

return sqltypes.MakeTrusted(sqltypes.Enum, val), nil
}
2 changes: 1 addition & 1 deletion sql/types/json.go
@@ -140,7 +140,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
if err != nil {
return sqltypes.NULL, err
}
val = AppendAndSliceBytes(dest, str)
val = str
} else {
// Convert to jsonType
jsVal, _, err := t.Convert(v)
2 changes: 1 addition & 1 deletion sql/types/set.go
@@ -246,7 +246,7 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val
snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
}
val := AppendAndSliceBytes(dest, encodedBytes)
val := encodedBytes

return sqltypes.MakeTrusted(sqltypes.Set, val), nil
}
5 changes: 4 additions & 1 deletion sql/types/sql_test.go
@@ -24,9 +24,12 @@ func BenchmarkNumI64SQL(b *testing.B) {
func BenchmarkVarchar10SQL(b *testing.B) {
var res sqltypes.Value
t := MustCreateStringWithDefaults(sqltypes.VarChar, 10)
buf := sql.NewByteBuffer(1000)
ctx := sql.NewEmptyContext()
for i := 0; i < b.N; i++ {
res, _ = t.SQL(ctx, nil, "char")
res, _ = t.SQL(ctx, buf.Get(), "char")
buf.Advance(res.Len())
buf.Reset()
}
result_ = res
}
14 changes: 8 additions & 6 deletions sql/types/strings.go
@@ -524,11 +524,10 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
start := len(dest)
var val []byte
if IsBinaryType(t) {
v, err = ConvertToBytes(v, t, dest)
val, err = ConvertToBytes(v, t, dest)
if err != nil {
return sqltypes.Value{}, err
}
val = AppendAndSliceBytes(dest, v.([]byte))
} else {
var valueBytes []byte
switch v := v.(type) {
Expand All @@ -540,7 +539,8 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
dest = append(dest, v...)
valueBytes = dest[start:]
case int, int8, int16, int32, int64:
num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v)
if err != nil {
@@ -555,10 +555,11 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
valueBytes = strconv.AppendUint(dest, num, 10)
case bool:
if v {
valueBytes = append(dest, '1')
dest = append(dest, '1')
} else {
valueBytes = append(dest, '0')
dest = append(dest, '0')
}
valueBytes = dest[start:]
case float64:
valueBytes = strconv.AppendFloat(dest, v, 'f', -1, 64)
if valueBytes[start] == '-' {
@@ -600,7 +601,8 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
snippetStr := strings2.ToValidUTF8(string(snippet), string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(snippetStr), snippet)
}
val = AppendAndSliceBytes(dest, encodedBytes)
val = encodedBytes
}

return sqltypes.MakeTrusted(t.baseType, val), nil
2 changes: 1 addition & 1 deletion sql/types/time.go
@@ -271,7 +271,7 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes
return sqltypes.Value{}, err
}

val := AppendAndSliceBytes(dest, ti.Bytes())
val := ti.Bytes()
return sqltypes.MakeTrusted(sqltypes.Time, val), nil
}
