3 changes: 2 additions & 1 deletion enginetest/enginetests.go
@@ -5743,6 +5743,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
require.NoError(t, err)
expectedRowSet := script.Results[queryIdx]
expectedRowIdx := 0
buf := sql.NewByteBuffer(1000)
var engineRow sql.Row
for engineRow, err = engineIter.Next(ctx); err == nil; engineRow, err = engineIter.Next(ctx) {
if !assert.True(t, r.Next()) {
@@ -5760,7 +5761,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
break
}
expectedEngineRow := make([]*string, len(engineRow))
row, err := server.RowToSQL(ctx, sch, engineRow, nil)
row, err := server.RowToSQL(ctx, sch, engineRow, nil, buf)
if !assert.NoError(t, err) {
break
}
2 changes: 1 addition & 1 deletion go.mod
@@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
4 changes: 2 additions & 2 deletions go.sum
@@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54 h1:nzBnC0Rt1gFtscJEz4veYd/mazZEdbdmed+tujdaKOo=
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7 h1:w130WLeARGGNYWmhGPugsHXzJEelKKimt3kTWg6/Puk=
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
42 changes: 25 additions & 17 deletions server/handler.go
@@ -442,15 +442,21 @@ func (h *Handler) doQuery(
var r *sqltypes.Result
var processedAtLeastOneBatch bool

buf := sql.ByteBufPool.Get().(*sql.ByteBuffer)
defer func() {
buf.Reset()
sql.ByteBufPool.Put(buf)
}()

// zero/single return schema use spooling shortcut
if types.IsOkResultSchema(schema) {
r, err = resultForOkIter(sqlCtx, rowIter)
} else if schema == nil {
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
} else {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more)
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
}
if err != nil {
return remainder, err
@@ -542,7 +548,7 @@ func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
}

// resultForMax1RowIter ensures that an empty iterator returns at most one row
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) {
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer) (*sqltypes.Result, error) {
defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End()
row, err := iter.Next(ctx)
if err == io.EOF {
@@ -557,7 +563,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
if err := iter.Close(ctx); err != nil {
return nil, err
}
outputRow, err := RowToSQL(ctx, schema, row, nil)
outputRow, err := RowToSQL(ctx, schema, row, nil, buf)
if err != nil {
return nil, err
}
@@ -569,14 +575,7 @@

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *Handler) resultForDefaultIter(
ctx *sql.Context,
c *mysql.Conn,
schema sql.Schema,
iter sql.RowIter,
callback func(*sqltypes.Result, bool) error,
resultFields []*querypb.Field,
more bool) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()
@@ -669,7 +668,7 @@ func (h *Handler) resultForDefaultIter(
continue
}

outputRow, err := RowToSQL(ctx, schema, row, projs)
outputRow, err := RowToSQL(ctx, schema, row, projs, buf)
if err != nil {
return err
}
@@ -932,21 +931,30 @@ func updateMaxUsedConnectionsStatusVariable() {
}()
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) {
func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interface{}) (sqltypes.Value, error) {
if buf == nil {
return typ.SQL(ctx, nil, val)
}
ret, err := typ.SQL(ctx, buf.Get(), val)
buf.Grow(ret.Len())
return ret, err
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, buf *sql.ByteBuffer) ([]sqltypes.Value, error) {
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
if len(sch) == 0 {
return []sqltypes.Value{}, nil
}

outVals := make([]sqltypes.Value, len(sch))
var err error
if len(projs) == 0 {
for i, col := range sch {
if row[i] == nil {
outVals[i] = sqltypes.NULL
continue
}
var err error
outVals[i], err = col.Type.SQL(ctx, nil, row[i])
outVals[i], err = toSqlHelper(ctx, col.Type, buf, row[i])
if err != nil {
return nil, err
}
@@ -963,7 +971,7 @@
outVals[i] = sqltypes.NULL
continue
}
outVals[i], err = col.Type.SQL(ctx, nil, field)
outVals[i], err = toSqlHelper(ctx, col.Type, buf, field)
if err != nil {
return nil, err
}
16 changes: 12 additions & 4 deletions server/handler_test.go
@@ -20,6 +20,7 @@ import (
"io"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -789,13 +790,14 @@ func TestHandlerKillQuery(t *testing.T) {

var wg sync.WaitGroup
wg.Add(1)
sleepQuery := "SELECT SLEEP(1)"
sleepQuery := "SELECT SLEEP(100000)"
var sleepErr error
go func() {
defer wg.Done()
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
// need a local |err| variable to avoid being overwritten
sleepErr = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
return nil
})
require.Error(err)
}()

time.Sleep(100 * time.Millisecond)
@@ -805,12 +807,17 @@
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQuery := false
fmt.Println(res.Rows[0][0], res.Rows[0][4], res.Rows[0][7])
fmt.Println(res.Rows[1][0], res.Rows[1][4], res.Rows[1][7])
for _, row := range res.Rows {
if row[7].ToString() != sleepQuery {
continue
}
hasSleepQuery = true
sleepQueryID = row[0].ToString()
// the values inside a callback are generally only valid for the
// duration of the query, and need to be copied to avoid being
// overwritten
sleepQueryID = strings.Clone(row[0].ToString())
require.Equal("Query", row[4].ToString())
}
require.True(hasSleepQuery)
@@ -824,6 +831,7 @@
})
require.NoError(err)
wg.Wait()
require.Error(sleepErr)

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
57 changes: 57 additions & 0 deletions sql/byte_buffer.go
@@ -0,0 +1,57 @@
package sql

import (
"sync"
)

const defaultByteBuffCap = 4096

var ByteBufPool = sync.Pool{
New: func() any {
return NewByteBuffer(defaultByteBuffCap)
},
}

type ByteBuffer struct {
Member:
This idea has merit and is surely better than allocating a buffer for each request but the way you're managing the memory is suboptimal. Also good to use the same backing array for multiple values in a request.

In the main use of this object in the handler, you're getting a zero-length slice (which has some larger backing array) and then calling append on it byte by byte. This will grow the backing array in some cases, but it's not being done under your deliberate control. Rather, you're then calling Double if the backing array is low on space after the appends have already happened.

Basically: in these methods, you are referring to the len of the byte slice, when your concern is usually the cap. It's fine to let append happen byte by byte as long as they aren't doubling the backing array too often, that's the expensive bit.

I think this would probably work slightly better if you just scrapped the explicit capacity management altogether and just let the Go runtime manage it automatically for you. Either that, or always manage it explicitly yourself, i.e. before you serialize a value with all those append calls. But it's not clear to me that manual management is actually any better if you use the same strategy as the go runtime does (double once we're full).

Contributor Author:
I agree with all of this, but there are two caveats that limit our ability to let the runtime handle this for us. (1) The runtime chooses doubling based on the cap of the slice, not the full backing array, so for the current setup the doubled array is often smaller than the original backing array. (2) Doubled arrays are not reference-swapped; we need a handle to the new buffer to know when to grow.

I'm not aware of a way to override the default runtime growth behavior to ignore the slice cap and instead double based on the backing array cap. So ByteBuffer still does the doubling, and exposes a Grow(n int) interface to track when that should happen. We pay for two mallocs on doubling, because the first one is never big enough. Not calling Grow after appending, or growing by less than the number of bytes written, will stomp previously written memory.

i int
buf []byte
}

func NewByteBuffer(initCap int) *ByteBuffer {
buf := make([]byte, initCap)
return &ByteBuffer{buf: buf}
}

// Grow records the latest used byte position. Callers
// are responsible for accurately reporting which bytes
// they expect to be protected.
func (b *ByteBuffer) Grow(n int) {
if b.i+n > len(b.buf) {
Member:

Actually, is there an off-by-one error here? A test at the boundary would be good.

Contributor Author (@max-hoffman, Dec 20, 2024):
yeah, not an error maybe but definitely better to preemptively double when we know the next Grow will exceed the buffer

// Runtime alloc'd into a separate backing array, but it chooses
// the doubling cap using the non-optimal |cap(b.buf)-b.i|*2.
// We do not need to increment |b.i| b/c the latest value is in
// the other array.
b.Double()
} else {
b.i += n
}
}

// Double expands the backing array by 2x. We do this
// here because the runtime only doubles based on slice
// length.
func (b *ByteBuffer) Double() {
buf := make([]byte, len(b.buf)*2)
copy(buf, b.buf)
b.buf = buf
}

// Get returns a zero length slice beginning at a safe
// write position.
func (b *ByteBuffer) Get() []byte {
return b.buf[b.i:b.i]
}

func (b *ByteBuffer) Reset() {
b.i = 0
}
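
The review thread above pins down a contract that is easy to misuse: Get hands out a zero-length slice at the current write position, the caller appends its serialized bytes into it, and must then report the bytes it used via Grow, otherwise the next Get returns the same region and stomps the previous value. A minimal usage sketch of that contract (illustration only, not code from this PR; it assumes the NewByteBuffer/Get/Grow/Reset API added in sql/byte_buffer.go):

	// Sketch of the intended ByteBuffer usage, assuming the API above;
	// an illustration, not part of the PR.
	package main

	import (
		"fmt"

		"github.com/dolthub/go-mysql-server/sql"
	)

	func main() {
		buf := sql.NewByteBuffer(16)

		// Serialize one value: append into the zero-length slice from Get,
		// then record the bytes used so the next Get starts past them.
		a := append(buf.Get(), "hello"...)
		buf.Grow(len(a)) // skipping this lets the next Get overwrite "hello"

		// The second value lands in a disjoint region of the same backing array.
		b := append(buf.Get(), "world"...)
		buf.Grow(len(b))

		fmt.Println(string(a), string(b)) // hello world

		// Reset reclaims the buffer for the next request; slices handed
		// out earlier must not be used after this point.
		buf.Reset()
	}

Note the caveat from the comments in Grow: if an append outgrows the remaining capacity, the runtime reallocates and the returned slice no longer aliases buf's backing array, which is why Grow doubles rather than advancing i on that path.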
2 changes: 1 addition & 1 deletion sql/types/bit.go
@@ -214,7 +214,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 {
data[i], data[j] = data[j], data[i]
}
val := AppendAndSliceBytes(dest, data)
val := data

return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
}
2 changes: 1 addition & 1 deletion sql/types/datetime.go
@@ -410,7 +410,7 @@ func (t datetimeType) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes.
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
}

valBytes := AppendAndSliceBytes(dest, val)
valBytes := val

return sqltypes.MakeTrusted(typ, valBytes), nil
}
2 changes: 1 addition & 1 deletion sql/types/enum.go
@@ -257,7 +257,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
}
val := AppendAndSliceBytes(dest, encodedBytes)
val := encodedBytes

return sqltypes.MakeTrusted(sqltypes.Enum, val), nil
}
2 changes: 1 addition & 1 deletion sql/types/json.go
@@ -140,7 +140,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
if err != nil {
return sqltypes.NULL, err
}
val = AppendAndSliceBytes(dest, str)
val = str
} else {
// Convert to jsonType
jsVal, _, err := t.Convert(v)
2 changes: 2 additions & 0 deletions sql/types/number.go
@@ -613,6 +613,8 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt
default:
return sqltypes.Value{}, err
}
} else if err != nil {
return sqltypes.Value{}, err
}

val := dest[stop:]
2 changes: 1 addition & 1 deletion sql/types/set.go
@@ -246,7 +246,7 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val
snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
}
val := AppendAndSliceBytes(dest, encodedBytes)
val := encodedBytes

return sqltypes.MakeTrusted(sqltypes.Set, val), nil
}
5 changes: 4 additions & 1 deletion sql/types/sql_test.go
@@ -24,9 +24,12 @@ func BenchmarkNumI64SQL(b *testing.B) {
func BenchmarkVarchar10SQL(b *testing.B) {
var res sqltypes.Value
t := MustCreateStringWithDefaults(sqltypes.VarChar, 10)
buf := sql.NewByteBuffer(1000)
ctx := sql.NewEmptyContext()
for i := 0; i < b.N; i++ {
res, _ = t.SQL(ctx, nil, "char")
res, _ = t.SQL(ctx, buf.Get(), "char")
buf.Grow(res.Len())
buf.Reset()
}
result_ = res
}
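
The boundary question raised in the byte_buffer.go review thread could be pinned down with a small test. A hypothetical sketch (the test name and assertions are ours, not part of this PR, and assume the Get/Grow/Double semantics shown above):

	// Hypothetical boundary test for ByteBuffer.Grow; not part of this PR.
	func TestByteBufferGrowBoundary(t *testing.T) {
		buf := sql.NewByteBuffer(8)

		// Consuming exactly the full capacity (i+n == len(buf)) takes the
		// non-doubling branch and leaves no writable space behind Get.
		buf.Grow(8)
		if c := cap(buf.Get()); c != 0 {
			t.Fatalf("expected no spare capacity at the boundary, got %d", c)
		}

		// The next Grow exceeds the backing array and doubles it; i is
		// deliberately not advanced on that path (see the comment in Grow).
		buf.Grow(1)
		if c := cap(buf.Get()); c != 8 {
			t.Fatalf("expected a doubled buffer with 8 spare bytes, got %d", c)
		}
	}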
14 changes: 8 additions & 6 deletions sql/types/strings.go
@@ -524,11 +524,10 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
start := len(dest)
var val []byte
if IsBinaryType(t) {
v, err = ConvertToBytes(v, t, dest)
val, err = ConvertToBytes(v, t, dest)
if err != nil {
return sqltypes.Value{}, err
}
val = AppendAndSliceBytes(dest, v.([]byte))
} else {
var valueBytes []byte
switch v := v.(type) {
@@ -540,7 +539,8 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
dest = append(dest, v...)
valueBytes = dest[start:]
case int, int8, int16, int32, int64:
num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v)
if err != nil {
@@ -555,10 +555,11 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
valueBytes = strconv.AppendUint(dest, num, 10)
case bool:
if v {
valueBytes = append(dest, '1')
dest = append(dest, '1')
} else {
valueBytes = append(dest, '0')
dest = append(dest, '0')
}
valueBytes = dest[start:]
case float64:
valueBytes = strconv.AppendFloat(dest, v, 'f', -1, 64)
if valueBytes[start] == '-' {
@@ -600,7 +601,8 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
snippetStr := strings2.ToValidUTF8(string(snippet), string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(snippetStr), snippet)
}
val = AppendAndSliceBytes(dest, encodedBytes)
//val = AppendAndSliceBytes(dest, encodedBytes)
val = encodedBytes
}

return sqltypes.MakeTrusted(t.baseType, val), nil
2 changes: 1 addition & 1 deletion sql/types/time.go
@@ -271,7 +271,7 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes
return sqltypes.Value{}, err
}

val := AppendAndSliceBytes(dest, ti.Bytes())
val := ti.Bytes()
return sqltypes.MakeTrusted(sqltypes.Time, val), nil
}
