Skip to content

Commit a231fb5

Browse files
authored
Implement sql.ValueRow (#3248)
1 parent 594c1e0 commit a231fb5

38 files changed

+3289
-639
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ require (
1414
github.com/lestrrat-go/strftime v1.0.4
1515
github.com/pkg/errors v0.9.1
1616
github.com/pmezard/go-difflib v1.0.0
17-
github.com/shopspring/decimal v1.3.1
17+
github.com/shopspring/decimal v1.4.0
1818
github.com/sirupsen/logrus v1.8.1
1919
github.com/stretchr/testify v1.9.0
2020
go.opentelemetry.io/otel v1.31.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
6666
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
6767
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
6868
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
69-
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
70-
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
69+
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
70+
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
7171
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
7272
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
7373
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

server/handler.go

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ func (h *Handler) doQuery(
495495
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
496496
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
497497
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
498+
} else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.IsValueRowIter(sqlCtx) {
499+
r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more)
498500
} else {
499501
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
500502
}
@@ -768,6 +770,149 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
768770
return r, processedAtLeastOneBatch, nil
769771
}
770772

773+
func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.ValueRowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
774+
defer trace.StartRegion(ctx, "Handler.resultForValueRowIter").End()
775+
776+
eg, ctx := ctx.NewErrgroup()
777+
pan2err := func(err *error) {
778+
if recoveredPanic := recover(); recoveredPanic != nil {
779+
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, debug.Stack())
780+
*err = goerrors.Join(*err, wrappedErr)
781+
}
782+
}
783+
784+
// TODO: poll for closed connections should obviously also run even if
785+
// we're doing something with an OK result or a single row result, etc.
786+
// This should be in the caller.
787+
pollCtx, cancelF := ctx.NewSubContext()
788+
eg.Go(func() (err error) {
789+
defer pan2err(&err)
790+
return h.pollForClosedConnection(pollCtx, c)
791+
})
792+
793+
// Default waitTime is one minute if there is no timeout configured, in which case
794+
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
795+
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
796+
// call Handler.CloseConnection()
797+
waitTime := 1 * time.Minute
798+
if h.readTimeout > 0 {
799+
waitTime = h.readTimeout
800+
}
801+
timer := time.NewTimer(waitTime)
802+
defer timer.Stop()
803+
804+
wg := sync.WaitGroup{}
805+
wg.Add(2)
806+
807+
// Wrap the callback to include a BytesBuffer.Reset() for non-cursor requests, to
808+
// clean out rows that have already been spooled.
809+
resetCallback := func(r *sqltypes.Result, more bool) error {
810+
// A server-side cursor allows the caller to fetch results cached on the server-side,
811+
// so if a cursor exists, we can't release the buffer memory yet.
812+
if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 {
813+
defer buf.Reset()
814+
}
815+
return callback(r, more)
816+
}
817+
818+
// TODO: send results instead of rows?
819+
// Read rows from iter and send them off
820+
var rowChan = make(chan sql.ValueRow, 512)
821+
eg.Go(func() (err error) {
822+
defer pan2err(&err)
823+
defer wg.Done()
824+
defer close(rowChan)
825+
for {
826+
select {
827+
case <-ctx.Done():
828+
return context.Cause(ctx)
829+
default:
830+
row, err := iter.NextValueRow(ctx)
831+
if err == io.EOF {
832+
return nil
833+
}
834+
if err != nil {
835+
return err
836+
}
837+
select {
838+
case rowChan <- row:
839+
case <-ctx.Done():
840+
return nil
841+
}
842+
}
843+
}
844+
})
845+
846+
var res *sqltypes.Result
847+
var processedAtLeastOneBatch bool
848+
eg.Go(func() (err error) {
849+
defer pan2err(&err)
850+
defer cancelF()
851+
defer wg.Done()
852+
for {
853+
if res == nil {
854+
res = &sqltypes.Result{
855+
Fields: resultFields,
856+
Rows: make([][]sqltypes.Value, 0, rowsBatch),
857+
}
858+
}
859+
if res.RowsAffected == rowsBatch {
860+
if err := resetCallback(res, more); err != nil {
861+
return err
862+
}
863+
res = nil
864+
processedAtLeastOneBatch = true
865+
continue
866+
}
867+
868+
select {
869+
case <-ctx.Done():
870+
return context.Cause(ctx)
871+
case <-timer.C:
872+
if h.readTimeout != 0 {
873+
// Cancel and return so Vitess can call the CloseConnection callback
874+
ctx.GetLogger().Tracef("connection timeout")
875+
return ErrRowTimeout.New()
876+
}
877+
case row, ok := <-rowChan:
878+
if !ok {
879+
return nil
880+
}
881+
resRow, err := RowValueToSQLValues(ctx, schema, row, buf)
882+
if err != nil {
883+
return err
884+
}
885+
ctx.GetLogger().Tracef("spooling result row %s", resRow)
886+
res.Rows = append(res.Rows, resRow)
887+
res.RowsAffected++
888+
if !timer.Stop() {
889+
<-timer.C
890+
}
891+
}
892+
timer.Reset(waitTime)
893+
}
894+
})
895+
896+
// Close() kills this PID in the process list,
897+
// wait until all rows have be sent over the wire
898+
eg.Go(func() (err error) {
899+
defer pan2err(&err)
900+
wg.Wait()
901+
return iter.Close(ctx)
902+
})
903+
904+
err := eg.Wait()
905+
if err != nil {
906+
ctx.GetLogger().WithError(err).Warn("error running query")
907+
if verboseErrorLogging {
908+
fmt.Printf("Err: %+v", err)
909+
}
910+
return nil, false, err
911+
}
912+
913+
return res, processedAtLeastOneBatch, nil
914+
}
915+
771916
// See https://dev.mysql.com/doc/internals/en/status-flags.html
772917
func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error {
773918
ok, err := isSessionAutocommit(ctx)
@@ -994,7 +1139,7 @@ func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interf
9941139
return typ.SQL(ctx, nil, val)
9951140
}
9961141
ret, err := typ.SQL(ctx, buf.Get(), val)
997-
buf.Grow(ret.Len())
1142+
buf.Grow(ret.Len()) // TODO: shouldn't we check capacity beforehand?
9981143
return ret, err
9991144
}
10001145

@@ -1037,6 +1182,39 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
10371182
return outVals, nil
10381183
}
10391184

1185+
func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf *sql.ByteBuffer) ([]sqltypes.Value, error) {
1186+
if len(sch) == 0 {
1187+
return []sqltypes.Value{}, nil
1188+
}
1189+
var err error
1190+
outVals := make([]sqltypes.Value, len(sch))
1191+
for i, col := range sch {
1192+
// TODO: remove this check once all Types implement this
1193+
valType, ok := col.Type.(sql.ValueType)
1194+
if !ok {
1195+
if row[i].IsNull() {
1196+
outVals[i] = sqltypes.NULL
1197+
continue
1198+
}
1199+
outVals[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val)
1200+
continue
1201+
}
1202+
if buf == nil {
1203+
outVals[i], err = valType.SQLValue(ctx, row[i], nil)
1204+
if err != nil {
1205+
return nil, err
1206+
}
1207+
continue
1208+
}
1209+
outVals[i], err = valType.SQLValue(ctx, row[i], buf.Get())
1210+
if err != nil {
1211+
return nil, err
1212+
}
1213+
buf.Grow(outVals[i].Len())
1214+
}
1215+
return outVals, nil
1216+
}
1217+
10401218
func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field {
10411219
charSetResults := ctx.GetCharacterSetResults()
10421220
fields := make([]*querypb.Field, len(s))

sql/analyzer/replace_sort.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,12 @@ func replaceIdxSortHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, so
175175
sortFields[i] = sortField
176176
} else {
177177
sameSortFields = false
178-
col2, _ := col.(sql.Expression2)
178+
valCol, _ := col.(sql.ValueExpression)
179179
sortFields[i] = sql.SortField{
180-
Column: col,
181-
Column2: col2,
182-
NullOrdering: sortField.NullOrdering,
183-
Order: sortField.Order,
180+
Column: col,
181+
ValueExprColumn: valCol,
182+
NullOrdering: sortField.NullOrdering,
183+
Order: sortField.Order,
184184
}
185185
}
186186
}

sql/cache.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ func (l *lruCache) Dispose() {
7171
}
7272

7373
type rowsCache struct {
74-
memory Freeable
75-
reporter Reporter
76-
rows []Row
77-
rows2 []Row2
74+
memory Freeable
75+
reporter Reporter
76+
rows []Row
77+
valueRows []ValueRow
7878
}
7979

8080
func newRowsCache(memory Freeable, r Reporter) *rowsCache {
@@ -92,17 +92,17 @@ func (c *rowsCache) Add(row Row) error {
9292

9393
func (c *rowsCache) Get() []Row { return c.rows }
9494

95-
func (c *rowsCache) Add2(row2 Row2) error {
95+
func (c *rowsCache) AddValueRow(row ValueRow) error {
9696
if !releaseMemoryIfNeeded(c.reporter, c.memory.Free) {
9797
return ErrNoMemoryAvailable.New()
9898
}
9999

100-
c.rows2 = append(c.rows2, row2)
100+
c.valueRows = append(c.valueRows, row)
101101
return nil
102102
}
103103

104-
func (c *rowsCache) Get2() []Row2 {
105-
return c.rows2
104+
func (c *rowsCache) GetValueRow() []ValueRow {
105+
return c.valueRows
106106
}
107107

108108
func (c *rowsCache) Dispose() {

sql/convert_value.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package sql
33
import (
44
"fmt"
55

6-
"github.com/dolthub/vitess/go/vt/proto/query"
7-
86
"github.com/dolthub/go-mysql-server/sql/values"
7+
8+
"github.com/dolthub/vitess/go/vt/proto/query"
99
)
1010

1111
// ConvertToValue converts the interface to a sql value.
@@ -90,11 +90,3 @@ func ConvertToValue(v interface{}) (Value, error) {
9090
return Value{}, fmt.Errorf("type %T not implemented", v)
9191
}
9292
}
93-
94-
func MustConvertToValue(v interface{}) Value {
95-
ret, err := ConvertToValue(v)
96-
if err != nil {
97-
panic(err)
98-
}
99-
return ret
100-
}

sql/core.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -460,13 +460,13 @@ func DebugString(nodeOrExpression interface{}) string {
460460
panic(fmt.Sprintf("Expected sql.DebugString or fmt.Stringer for %T", nodeOrExpression))
461461
}
462462

463-
// Expression2 is an experimental future interface alternative to Expression to provide faster access.
464-
type Expression2 interface {
463+
// ValueExpression is an experimental future interface alternative to Expression to provide faster access.
464+
type ValueExpression interface {
465465
Expression
466-
// Eval2 evaluates the given row frame and returns a result.
467-
Eval2(ctx *Context, row Row2) (Value, error)
468-
// Type2 returns the expression type.
469-
Type2() Type2
466+
// EvalValue evaluates the given row frame and returns a result.
467+
EvalValue(ctx *Context, row ValueRow) (Value, error)
468+
// IsValueExpression indicates whether this expression and all its children support ValueExpression.
469+
IsValueExpression() bool
470470
}
471471

472472
var SystemVariables SystemVariableRegistry

0 commit comments

Comments
 (0)