diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 1857872043..2fe9f14e58 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -5743,6 +5743,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve require.NoError(t, err) expectedRowSet := script.Results[queryIdx] expectedRowIdx := 0 + buf := sql.NewByteBuffer(1000) var engineRow sql.Row for engineRow, err = engineIter.Next(ctx); err == nil; engineRow, err = engineIter.Next(ctx) { if !assert.True(t, r.Next()) { @@ -5760,7 +5761,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve break } expectedEngineRow := make([]*string, len(engineRow)) - row, err := server.RowToSQL(ctx, sch, engineRow, nil) + row, err := server.RowToSQL(ctx, sch, engineRow, nil, buf) if !assert.NoError(t, err) { break } diff --git a/go.mod b/go.mod index e774bc733f..bf123fd6e6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54 + github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index a06fca9d22..03f7164365 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= 
-github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54 h1:nzBnC0Rt1gFtscJEz4veYd/mazZEdbdmed+tujdaKOo= -github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7 h1:w130WLeARGGNYWmhGPugsHXzJEelKKimt3kTWg6/Puk= +github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/server/handler.go b/server/handler.go index 6dc7625f4f..9bec7fc071 100644 --- a/server/handler.go +++ b/server/handler.go @@ -442,15 +442,21 @@ func (h *Handler) doQuery( var r *sqltypes.Result var processedAtLeastOneBatch bool + buf := sql.ByteBufPool.Get().(*sql.ByteBuffer) + defer func() { + buf.Reset() + sql.ByteBufPool.Put(buf) + }() + // zero/single return schema use spooling shortcut if types.IsOkResultSchema(schema) { r, err = resultForOkIter(sqlCtx, rowIter) } else if schema == nil { r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { - r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields) + r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) } else { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } if err != nil { return remainder, err @@ -542,7 +548,7 @@ func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) { } // resultForMax1RowIter ensures that 
an empty iterator returns at most one row -func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) { +func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer) (*sqltypes.Result, error) { defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End() row, err := iter.Next(ctx) if err == io.EOF { @@ -557,7 +563,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, if err := iter.Close(ctx); err != nil { return nil, err } - outputRow, err := RowToSQL(ctx, schema, row, nil) + outputRow, err := RowToSQL(ctx, schema, row, nil, buf) if err != nil { return nil, err } @@ -569,14 +575,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, // resultForDefaultIter reads batches of rows from the iterator // and writes results into the callback function. -func (h *Handler) resultForDefaultIter( - ctx *sql.Context, - c *mysql.Conn, - schema sql.Schema, - iter sql.RowIter, - callback func(*sqltypes.Result, bool) error, - resultFields []*querypb.Field, - more bool) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) { +func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) { defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() eg, ctx := ctx.NewErrgroup() @@ -669,7 +668,7 @@ func (h *Handler) resultForDefaultIter( continue } - outputRow, err := RowToSQL(ctx, schema, row, projs) + outputRow, err := RowToSQL(ctx, schema, row, projs, buf) if err != nil { return err } @@ -932,21 +931,30 @@ func updateMaxUsedConnectionsStatusVariable() { }() } -func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, 
projs []sql.Expression) ([]sqltypes.Value, error) { +func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interface{}) (sqltypes.Value, error) { + if buf == nil { + return typ.SQL(ctx, nil, val) + } + ret, err := typ.SQL(ctx, buf.Get(), val) + buf.Grow(ret.Len()) + return ret, err +} + +func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, buf *sql.ByteBuffer) ([]sqltypes.Value, error) { // need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock) if len(sch) == 0 { return []sqltypes.Value{}, nil } outVals := make([]sqltypes.Value, len(sch)) + var err error if len(projs) == 0 { for i, col := range sch { if row[i] == nil { outVals[i] = sqltypes.NULL continue } - var err error - outVals[i], err = col.Type.SQL(ctx, nil, row[i]) + outVals[i], err = toSqlHelper(ctx, col.Type, buf, row[i]) if err != nil { return nil, err } @@ -963,7 +971,7 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express outVals[i] = sqltypes.NULL continue } - outVals[i], err = col.Type.SQL(ctx, nil, field) + outVals[i], err = toSqlHelper(ctx, col.Type, buf, field) if err != nil { return nil, err } diff --git a/server/handler_test.go b/server/handler_test.go index 6aac1c4c22..6bc3f1e20b 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -20,6 +20,7 @@ import ( "io" "net" "strconv" + "strings" "sync" "testing" "time" @@ -789,13 +790,14 @@ func TestHandlerKillQuery(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - sleepQuery := "SELECT SLEEP(1)" + sleepQuery := "SELECT SLEEP(100000)" + var sleepErr error go func() { defer wg.Done() - err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { + // need a local |err| variable to avoid being overwritten + sleepErr = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error { return nil }) - require.Error(err) }() 
time.Sleep(100 * time.Millisecond)
@@ -805,12 +807,15 @@ func TestHandlerKillQuery(t *testing.T) {
 		// 2, , , test, Query, 0, running, SHOW PROCESSLIST
 		require.Equal(2, len(res.Rows))
 		hasSleepQuery := false
 		for _, row := range res.Rows {
 			if row[7].ToString() != sleepQuery {
 				continue
 			}
 			hasSleepQuery = true
-			sleepQueryID = row[0].ToString()
+			// the values inside a callback are generally only valid for the
+			// duration of the query, and need to be copied to avoid being
+			// overwritten
+			sleepQueryID = strings.Clone(row[0].ToString())
 			require.Equal("Query", row[4].ToString())
 		}
 		require.True(hasSleepQuery)
@@ -824,6 +829,7 @@ func TestHandlerKillQuery(t *testing.T) {
 	})
 	require.NoError(err)
 	wg.Wait()
+	require.Error(sleepErr)
 
 	time.Sleep(100 * time.Millisecond)
 	err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
diff --git a/sql/byte_buffer.go b/sql/byte_buffer.go
new file mode 100644
index 0000000000..03a977f209
--- /dev/null
+++ b/sql/byte_buffer.go
@@ -0,0 +1,100 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"sync"
+)
+
+// defaultByteBuffCap is the initial size of a pooled buffer's backing
+// array, and the size an empty backing array grows to when doubled.
+const defaultByteBuffCap = 4096
+
+// ByteBufPool recycles ByteBuffers so that their backing arrays can be
+// reused. A buffer should be Reset before it is put back in the pool.
+var ByteBufPool = sync.Pool{
+	New: func() any {
+		return NewByteBuffer(defaultByteBuffCap)
+	},
+}
+
+// ByteBuffer hands out consecutive regions of one backing array so that
+// many small values can be written without allocating for each of them.
+// Bytes before |i| belong to earlier writers; bytes from |i| on are free.
+// A ByteBuffer does no locking of its own.
+type ByteBuffer struct {
+	i   int
+	buf []byte
+}
+
+// NewByteBuffer returns an empty ByteBuffer whose backing array is
+// |initCap| bytes long.
+func NewByteBuffer(initCap int) *ByteBuffer {
+	buf := make([]byte, initCap)
+	return &ByteBuffer{buf: buf}
+}
+
+// Grow records that |n| bytes were appended to the slice returned by the
+// latest Get. Callers are responsible for accurately reporting which
+// bytes they expect to be protected.
+//
+// A write that fit in the backing array advances the write position. A
+// write that reached or passed the end of the backing array doubles it;
+// one that passed the end was moved to its own allocation by append, so
+// the write position stays put. Non-positive |n| is ignored.
+func (b *ByteBuffer) Grow(n int) {
+	if n <= 0 {
+		// Never rewind into bytes that earlier writers still own.
+		return
+	}
+	end := b.i + n
+	fits := end <= len(b.buf)
+	if end >= len(b.buf) {
+		// No more space, double.
+		// An external allocation doubled the cap using the size of
+		// the override object, which if used could lead to overall
+		// shrinking behavior.
+		b.Double()
+	}
+	if fits {
+		b.i = end
+	}
+}
+
+// Double expands the backing array by 2x. We do this
+// here because the runtime only doubles based on slice
+// length. An empty backing array grows to defaultByteBuffCap,
+// because doubling zero would never make room.
+func (b *ByteBuffer) Double() {
+	size := 2 * len(b.buf)
+	if size == 0 {
+		size = defaultByteBuffCap
+	}
+	buf := make([]byte, size)
+	copy(buf, b.buf)
+	b.buf = buf
+}
+
+// Get returns a zero length slice beginning at a safe
+// write position.
+func (b *ByteBuffer) Get() []byte {
+	return b.buf[b.i:b.i]
+}
+
+// Reset makes the whole backing array available for writing again.
+// Slices handed out earlier may be overwritten by later writers.
+func (b *ByteBuffer) Reset() {
+	b.i = 0
+}
diff --git a/sql/byte_buffer_test.go b/sql/byte_buffer_test.go
new file mode 100644
index 0000000000..afe67aa1b7
--- /dev/null
+++ b/sql/byte_buffer_test.go
@@ -0,0 +1,72 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGrowByteBuffer(t *testing.T) { + b := NewByteBuffer(10) + + // grow less than boundary + src1 := []byte{1, 1, 1} + obj1 := append(b.Get(), src1...) + b.Grow(len(src1)) + + require.Equal(t, 10, len(b.buf)) + require.Equal(t, 3, b.i) + require.Equal(t, 10, cap(obj1)) + + // grow to boundary + src2 := []byte{0, 0, 0, 0, 0, 0, 0} + obj2 := append(b.Get(), src2...) + b.Grow(len(src2)) + + require.Equal(t, 20, len(b.buf)) + require.Equal(t, 10, b.i) + require.Equal(t, 7, cap(obj2)) + + src3 := []byte{2, 2, 2, 2, 2} + obj3 := append(b.Get(), src3...) + b.Grow(len(src3)) + + require.Equal(t, 20, len(b.buf)) + require.Equal(t, 15, b.i) + require.Equal(t, 10, cap(obj3)) + + // grow exceeds boundary + + src4 := []byte{3, 3, 3, 3, 3, 3, 3, 3} + obj4 := append(b.Get(), src4...) + b.Grow(len(src4)) + + require.Equal(t, 40, len(b.buf)) + require.Equal(t, 15, b.i) + require.Equal(t, 16, cap(obj4)) + + // objects are all valid after doubling + require.Equal(t, src1, obj1) + require.Equal(t, src2, obj2) + require.Equal(t, src3, obj3) + require.Equal(t, src4, obj4) + + // reset + b.Reset() + require.Equal(t, 40, len(b.buf)) + require.Equal(t, 0, b.i) +} diff --git a/sql/types/bit.go b/sql/types/bit.go index b8be7a1c0b..fe71464aaa 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -214,7 +214,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { data[i], data[j] = data[j], data[i] } - val := AppendAndSliceBytes(dest, data) + val := data return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } diff --git a/sql/types/datetime.go b/sql/types/datetime.go index a956b71382..68f8f7db41 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -410,7 +410,7 @@ func (t datetimeType) SQL(_ 
*sql.Context, dest []byte, v interface{}) (sqltypes. return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime") } - valBytes := AppendAndSliceBytes(dest, val) + valBytes := val return sqltypes.MakeTrusted(typ, valBytes), nil } diff --git a/sql/types/enum.go b/sql/types/enum.go index 150412be2c..62fdc07585 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -257,7 +257,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError)) return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet) } - val := AppendAndSliceBytes(dest, encodedBytes) + val := encodedBytes return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } diff --git a/sql/types/json.go b/sql/types/json.go index 3587ed8926..adffb907d2 100644 --- a/sql/types/json.go +++ b/sql/types/json.go @@ -140,7 +140,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va if err != nil { return sqltypes.NULL, err } - val = AppendAndSliceBytes(dest, str) + val = str } else { // Convert to jsonType jsVal, _, err := t.Convert(v) diff --git a/sql/types/number.go b/sql/types/number.go index 41ae6d4671..6398c0ddae 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -613,6 +613,8 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt default: return sqltypes.Value{}, err } + } else if err != nil { + return sqltypes.Value{}, err } val := dest[stop:] diff --git a/sql/types/set.go b/sql/types/set.go index d67c3fe0c3..fa49d2ce29 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -246,7 +246,7 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError)) return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet) } - val := AppendAndSliceBytes(dest, 
encodedBytes) + val := encodedBytes return sqltypes.MakeTrusted(sqltypes.Set, val), nil } diff --git a/sql/types/sql_test.go b/sql/types/sql_test.go index e691426ae9..5c90f4ddf0 100644 --- a/sql/types/sql_test.go +++ b/sql/types/sql_test.go @@ -24,9 +24,12 @@ func BenchmarkNumI64SQL(b *testing.B) { func BenchmarkVarchar10SQL(b *testing.B) { var res sqltypes.Value t := MustCreateStringWithDefaults(sqltypes.VarChar, 10) + buf := sql.NewByteBuffer(1000) ctx := sql.NewEmptyContext() for i := 0; i < b.N; i++ { - res, _ = t.SQL(ctx, nil, "char") + res, _ = t.SQL(ctx, buf.Get(), "char") + buf.Grow(res.Len()) + buf.Reset() } result_ = res } diff --git a/sql/types/strings.go b/sql/types/strings.go index d51b7e47da..75890fe90a 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -524,11 +524,10 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. start := len(dest) var val []byte if IsBinaryType(t) { - v, err = ConvertToBytes(v, t, dest) + val, err = ConvertToBytes(v, t, dest) if err != nil { return sqltypes.Value{}, err } - val = AppendAndSliceBytes(dest, v.([]byte)) } else { var valueBytes []byte switch v := v.(type) { @@ -540,7 +539,8 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. case []byte: valueBytes = v case string: - valueBytes = []byte(v) + dest = append(dest, v...) + valueBytes = dest[start:] case int, int8, int16, int32, int64: num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v) if err != nil { @@ -555,10 +555,11 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. 
valueBytes = strconv.AppendUint(dest, num, 10)
 		case bool:
 			if v {
-				valueBytes = append(dest, '1')
+				dest = append(dest, '1')
 			} else {
-				valueBytes = append(dest, '0')
+				dest = append(dest, '0')
 			}
+			valueBytes = dest[start:]
 		case float64:
 			valueBytes = strconv.AppendFloat(dest, v, 'f', -1, 64)
 			if valueBytes[start] == '-' {
@@ -600,7 +601,7 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
 			snippetStr := strings2.ToValidUTF8(string(snippet), string(utf8.RuneError))
 			return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(snippetStr), snippet)
 		}
-		val = AppendAndSliceBytes(dest, encodedBytes)
+		val = encodedBytes
 	}
 
 	return sqltypes.MakeTrusted(t.baseType, val), nil
diff --git a/sql/types/time.go b/sql/types/time.go
index f42cb7b535..e2bc4b1a03 100644
--- a/sql/types/time.go
+++ b/sql/types/time.go
@@ -271,7 +271,7 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes
 		return sqltypes.Value{}, err
 	}
 
-	val := AppendAndSliceBytes(dest, ti.Bytes())
+	val := ti.Bytes()
 
 	return sqltypes.MakeTrusted(sqltypes.Time, val), nil
 }