From 192987e09402465c9432269e166a23c84dfbe97b Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 1 Oct 2025 10:53:59 -0700 Subject: [PATCH 01/59] implement row2 --- server/handler.go | 28 +++++++++++++ sql/plan/process.go | 23 +++++++++++ sql/rowexec/transaction_iters.go | 16 ++++++++ sql/rows.go | 8 +++- sql/table_iter.go | 70 ++++++++++++++++++++++++++++++++ sql/type.go | 1 - 6 files changed, 144 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index 2275ca7a2d..701c1141ca 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,6 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) + } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + r, err = h.resultForDefaultIter2(sqlCtx, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -768,6 +770,32 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, error) { + res := &sqltypes.Result{Fields: resultFields} + for { + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return nil, err + } + res = nil + } + row, err := iter.Next2(ctx) + if err == io.EOF { + return res, nil + } + if err != nil { + return nil, err + } + + outRow := make([]sqltypes.Value, len(res.Rows)) + for i := range row { + outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + } + res.Rows = append(res.Rows, outRow) + res.RowsAffected++ + } +} + // See https://dev.mysql.com/doc/internals/en/status-flags.html func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { ok, err := isSessionAutocommit(ctx) diff --git a/sql/plan/process.go b/sql/plan/process.go index ee95249f10..92f33ba19f 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -317,6 +317,29 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } +func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := i.iter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", i.iter)) + } + row, err := ri2.Next2(ctx) + if err != nil { + return nil, err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return row, nil +} + +func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { + if ri2, ok := i.iter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false +} + func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index db69cf5327..aacc1f095d 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,6 +15,7 @@ package rowexec import ( + "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -99,6 +100,21 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } +func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := t.childIter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", t.childIter)) + } + return ri2.Next2(ctx) +} + +func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { + if ri2, ok := t.childIter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false +} + func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { diff --git a/sql/rows.go b/sql/rows.go index a9e5f55d5c..191147ad68 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -92,6 +92,12 @@ type RowIter interface { Closer } +type RowIter2 interface { + RowIter + Next2(ctx *Context) (Row2, error) + IsRowIter2(ctx *Context) bool +} + // RowIterToRows converts a row iterator to a slice of rows. func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { var rows []Row @@ -112,7 +118,7 @@ func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { return rows, i.Close(ctx) } -func rowFromRow2(sch Schema, r Row2) Row { +func RowFromRow2(sch Schema, r Row2) Row { row := make(Row, len(sch)) for i, col := range sch { switch col.Type.Type() { diff --git a/sql/table_iter.go b/sql/table_iter.go index e302d5428a..6ac205c377 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -15,6 +15,7 @@ package sql import ( + "fmt" "io" ) @@ -24,6 +25,8 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter + + rows2 RowIter2 } var _ RowIter = (*TableRowIter)(nil) @@ -76,6 +79,73 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } +func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return nil, e + } + } + + return nil, err + } + + i.partition = partition + } + + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return nil, err + } + ri2, ok := rows.(RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement RowIter2", rows)) + } + i.rows2 = ri2 + } + + row, err := i.rows2.Next2(ctx) + if err != nil && err == io.EOF { + if err = i.rows2.Close(ctx); err != nil { + return nil, err + } + i.partition = nil + i.rows2 = nil + row, err = i.Next2(ctx) + } + return row, err +} + +func (i *TableRowIter) IsRowIter2(ctx *Context) bool { + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + return false + } + i.partition = partition + } + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return false + } + ri2, ok := rows.(RowIter2) + if !ok { + return false + } + i.rows2 = ri2 + } + return i.rows2.IsRowIter2(ctx) +} + func (i *TableRowIter) Close(ctx *Context) error { if i.rows != nil { if err := i.rows.Close(ctx); err != nil { diff --git a/sql/type.go b/sql/type.go index 59af5360f1..e4d0f8ff96 100644 --- a/sql/type.go +++ b/sql/type.go @@ -294,7 +294,6 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type - // Compare2 returns an integer comparing two Values. Compare2(Value, Value) (int, error) // Convert2 converts a value of a compatible type. From 7a2a4ec33619d4d89cfa8f3f09902c3fb5ce3ca7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 1 Oct 2025 16:37:57 -0700 Subject: [PATCH 02/59] better --- server/handler.go | 117 +++++++++++++++++++++++++++++++++++++-------- sql/plan/filter.go | 24 ++++++++++ 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/server/handler.go b/server/handler.go index 701c1141ca..bd82c27dfa 100644 --- a/server/handler.go +++ b/server/handler.go @@ -496,7 +496,7 @@ func (h *Handler) doQuery( } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, err = h.resultForDefaultIter2(sqlCtx, ri2, resultFields, callback, more) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -770,30 +770,107 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } -func (h *Handler) resultForDefaultIter2(ctx *sql.Context, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, error) { - res := &sqltypes.Result{Fields: resultFields} - for { - if res.RowsAffected == rowsBatch { - if err := callback(res, more); err != nil { - return nil, err - } - res = nil - } - row, err := iter.Next2(ctx) - if err == io.EOF { - return res, nil +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + stack := debug.Stack() + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + *err = goerrors.Join(*err, wrappedErr) } - if err != nil { - return nil, err + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(1) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{Fields: resultFields} + } + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + // TODO: timer should probably go in its own thread, as rowChan is blocking + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + outRow := make([]sqltypes.Value, len(row)) + for i := range row { + outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + } + res.Rows = append(res.Rows, outRow) + res.RowsAffected++ + } + timer.Reset(waitTime) } + }) - outRow := make([]sqltypes.Value, len(res.Rows)) - for i := range row { - outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) } - res.Rows = append(res.Rows, outRow) - res.RowsAffected++ + return nil, false, err } + + return res, processedAtLeastOneBatch, nil } // See https://dev.mysql.com/doc/internals/en/status-flags.html diff --git a/sql/plan/filter.go b/sql/plan/filter.go index f2c0691112..57c64664df 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,6 +15,7 @@ package plan import ( + "fmt" "github.com/dolthub/go-mysql-server/sql" ) @@ -133,6 +134,29 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } +func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := i.childIter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T is not a sql.RowIter2", i.childIter)) + } + + for { + row, err := ri2.Next(ctx) + if err != nil { + return nil, err + } + + res, err := sql.EvaluateCondition(ctx, i.cond, row) + if err != nil { + return nil, err + } + + if sql.IsTrue(res) { + return nil, nil + } + } +} + // Close implements the RowIter interface. func (i *FilterIter) Close(ctx *sql.Context) error { return i.childIter.Close(ctx) From 993a68834f9a62c9ac9310f9895f0d36eaa5b62b Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 2 Oct 2025 17:59:53 +0000 Subject: [PATCH 03/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/plan/filter.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 57c64664df..8be34ec2c7 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -16,6 +16,7 @@ package plan import ( "fmt" + "github.com/dolthub/go-mysql-server/sql" ) From 6deaf9d3f78a09c5f78819cfbffcf089cc7cf7b2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 10 Oct 2025 13:28:31 -0700 Subject: [PATCH 04/59] disable row2 --- server/handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index bd82c27dfa..cc2621f54c 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + //} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + // r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } From da6b39c1c5efc24662fdbbff3d71293166d9115e Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 10 Oct 2025 16:28:23 -0700 Subject: [PATCH 05/59] reenable row2 --- server/handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index cc2621f54c..bd82c27dfa 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - //} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - // r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } From c6d99b4c65af452a96441598b33beb4d1595c839 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 12 Oct 2025 17:52:13 -0700 Subject: [PATCH 06/59] implement expr2 for filters --- sql/core.go | 1 + sql/expression/comparison.go | 56 ++++++++++++++++++++++++++++++++++++ sql/expression/get_field.go | 4 +++ sql/expression/literal.go | 4 +++ sql/expression/unresolved.go | 4 +++ sql/plan/filter.go | 24 +++++++++++++--- 6 files changed, 89 insertions(+), 4 deletions(-) diff --git a/sql/core.go b/sql/core.go index c1e1f90b2a..c2996039eb 100644 --- a/sql/core.go +++ b/sql/core.go @@ -467,6 +467,7 @@ type Expression2 interface { Eval2(ctx *Context, row Row2) (Value, error) // Type2 returns the expression type. Type2() Type2 + IsExpr2() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index e7cb0b8e15..6e4fee29ae 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -15,7 +15,9 @@ package expression import ( + "bytes" "fmt" + querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -521,6 +523,60 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { + l, ok := gt.Left().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) + } + r, ok := gt.Right().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right())) + } + + lv, err := l.Eval2(ctx, row) + if err != nil { + return sql.Value{}, nil + } + rv, err := r.Eval2(ctx, row) + if err != nil { + return sql.Value{}, nil + } + + // TODO: better implementation + res := bytes.Compare(lv.Val, rv.Val) // TODO: this is probably wrong + var rb byte + if res == 1 { + rb = 1 + } + ret := sql.Value{ + Val: sql.ValueBytes{rb}, + Typ: querypb.Type_INT8, + } + return ret, nil +} + +func (gt *GreaterThan) Type2() sql.Type2 { + return nil +} + +func (gt *GreaterThan) IsExpr2() bool { + lExpr, isExpr2 := gt.Left().(sql.Expression2) + if !isExpr2 { + return false + } + if !lExpr.IsExpr2() { + return false + } + rExpr, isExpr2 := gt.Right().(sql.Expression2) + if !isExpr2 { + return false + } + if !rExpr.IsExpr2() { + return false + } + return true +} + // WithChildren implements the Expression interface. func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 2 { diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index f4ff9b429e..5e9263760f 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -157,6 +157,10 @@ func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return row.GetField(p.fieldIndex), nil } +func (p *GetField) IsExpr2() bool { + return true +} + // WithChildren implements the Expression interface. func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 0 { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 8fff9557a7..2b5583dc5b 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -140,6 +140,10 @@ func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return lit.val2, nil } +func (lit *Literal) IsExpr2() bool { + return true +} + func (lit *Literal) Type2() sql.Type2 { t2, ok := lit.Typ.(sql.Type2) if !ok { diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 78e2e9d0b9..c421699722 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -79,6 +79,10 @@ func (uc *UnresolvedColumn) Type2() sql.Type2 { panic("unresolved column is a placeholder node, but Type2 was called") } +func (uc *UnresolvedColumn) IsExpr2() bool { + panic("unresolved column is a placeholder node, but IsExpr2 was called") +} + // Name implements the Nameable interface. func (uc *UnresolvedColumn) Name() string { return uc.name } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 8be34ec2c7..b9edb0aa41 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -142,20 +142,36 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { } for { - row, err := ri2.Next(ctx) + row, err := ri2.Next2(ctx) if err != nil { return nil, err } - res, err := sql.EvaluateCondition(ctx, i.cond, row) + // TODO: write EvaluateCondition2? + cond, isCond2 := i.cond.(sql.Expression2) + if !isCond2 { + panic(fmt.Sprintf("%T does not implement sql.Expression2 interface", i.cond)) + } + res, err := cond.Eval2(ctx, row) if err != nil { return nil, err } + if res.Val[0] == 1 { + return row, nil + } + } +} - if sql.IsTrue(res) { - return nil, nil +func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { + if cond, isExpr2 := i.cond.(sql.Expression2); isExpr2 { + if !cond.IsExpr2() { + return false } } + if ri2, ok := i.childIter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false } // Close implements the RowIter interface. From 6dd030997066e15965b1ccba01d92b1a9303032e Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 13 Oct 2025 00:55:24 +0000 Subject: [PATCH 07/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 6e4fee29ae..c93c403f63 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,8 +17,8 @@ package expression import ( "bytes" "fmt" - querypb "github.com/dolthub/vitess/go/vt/proto/query" + querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" From f9c35e86456fa109c764c886984932eb561a85c0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 01:02:58 -0700 Subject: [PATCH 08/59] reduce type asserts --- sql/plan/filter.go | 36 +++++++++++++------------------- sql/plan/process.go | 15 +++++++------ sql/rowexec/transaction_iters.go | 16 +++++++------- sql/table_iter.go | 2 +- 4 files changed, 29 insertions(+), 40 deletions(-) diff --git a/sql/plan/filter.go b/sql/plan/filter.go index b9edb0aa41..c2bf35c50e 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,8 +15,6 @@ package plan import ( - "fmt" - "github.com/dolthub/go-mysql-server/sql" ) @@ -106,6 +104,9 @@ func (f *Filter) Expressions() []sql.Expression { type FilterIter struct { cond sql.Expression childIter sql.RowIter + + cond2 sql.Expression2 + childIter2 sql.RowIter2 } // NewFilterIter creates a new FilterIter. @@ -136,23 +137,12 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := i.childIter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T is not a sql.RowIter2", i.childIter)) - } - for { - row, err := ri2.Next2(ctx) + row, err := i.childIter2.Next2(ctx) if err != nil { return nil, err } - - // TODO: write EvaluateCondition2? - cond, isCond2 := i.cond.(sql.Expression2) - if !isCond2 { - panic(fmt.Sprintf("%T does not implement sql.Expression2 interface", i.cond)) - } - res, err := cond.Eval2(ctx, row) + res, err := i.cond2.Eval2(ctx, row) if err != nil { return nil, err } @@ -163,15 +153,17 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { } func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { - if cond, isExpr2 := i.cond.(sql.Expression2); isExpr2 { - if !cond.IsExpr2() { - return false - } + cond, ok := i.cond.(sql.Expression2) + if !ok || !cond.IsExpr2() { + return false } - if ri2, ok := i.childIter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + childIter, ok := i.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false } - return false + i.cond2 = cond + i.childIter2 = childIter + return true } // Close implements the RowIter interface. diff --git a/sql/plan/process.go b/sql/plan/process.go index 92f33ba19f..70a687247f 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -226,6 +226,7 @@ const ( type TrackedRowIter struct { node sql.Node iter sql.RowIter + iter2 sql.RowIter2 onDone NotifyFunc onNext NotifyFunc numRows int64 @@ -318,11 +319,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { } func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := i.iter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", i.iter)) - } - row, err := ri2.Next2(ctx) + row, err := i.iter2.Next2(ctx) if err != nil { return nil, err } @@ -334,10 +331,12 @@ func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { } func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { - if ri2, ok := i.iter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + iter, ok := i.iter.(sql.RowIter2) + if !ok || !iter.IsRowIter2(ctx) { + return false } - return false + i.iter2 = iter + return true } func (i *TrackedRowIter) Close(ctx *sql.Context) error { diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index aacc1f095d..f0f56168ef 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,7 +15,6 @@ package rowexec import ( - "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -72,6 +71,7 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { // during the Close() operation type TransactionCommittingIter struct { childIter sql.RowIter + childIter2 sql.RowIter2 transactionDatabase string autoCommit bool implicitCommit bool @@ -101,18 +101,16 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { } func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := t.childIter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", t.childIter)) - } - return ri2.Next2(ctx) + return t.childIter2.Next2(ctx) } func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { - if ri2, ok := t.childIter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + childIter, ok := t.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false } - return false + t.childIter2 = childIter + return true } func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { diff --git a/sql/table_iter.go b/sql/table_iter.go index 6ac205c377..884778307a 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -106,7 +106,7 @@ func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { return nil, err } ri2, ok := rows.(RowIter2) - if !ok { + if !ok || !ri2.IsRowIter2(ctx) { panic(fmt.Sprintf("%T does not implement RowIter2", rows)) } i.rows2 = ri2 From ec8b4dceff0f7cb3073e035b51be4d6e4db2802e Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 01:09:45 -0700 Subject: [PATCH 09/59] split send and receive --- server/handler.go | 46 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/server/handler.go b/server/handler.go index bd82c27dfa..34ca72882d 100644 --- a/server/handler.go +++ b/server/handler.go @@ -771,7 +771,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s } func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { - defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() + defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End() eg, ctx := ctx.NewErrgroup() pan2err := func(err *error) { @@ -803,7 +803,34 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) + + // TODO: this should be merged below go func + var rowChan = make(chan sql.Row2, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -813,7 +840,10 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer wg.Done() for { if res == nil { - res = &sqltypes.Result{Fields: resultFields} + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } } if res.RowsAffected == rowsBatch { if err := callback(res, more); err != nil { @@ -834,14 +864,12 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - default: - row, err := iter.Next2(ctx) - if err == io.EOF { + case row, ok := <-rowChan: + if !ok { return nil } - if err != nil { - return err - } + // TODO: we can avoid deep copy here by redefining sql.Row2 + ctx.GetLogger().Tracef("spooling result row %s", row) outRow := make([]sqltypes.Value, len(row)) for i := range row { outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) From 84997e122cf6dbcf545d960fbd33dccf23be86bd Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 11:52:39 -0700 Subject: [PATCH 10/59] directly return rows --- server/handler.go | 47 +++------- sql/convert_value.go | 92 +++++-------------- sql/core.go | 3 +- sql/expression/comparison.go | 28 +++--- sql/expression/get_field.go | 8 +- sql/expression/literal.go | 5 +- sql/expression/unresolved.go | 3 +- sql/plan/filter.go | 5 +- sql/row_frame.go | 15 ++-- sql/rows.go | 24 ++--- sql/type.go | 8 +- sql/types/number.go | 168 +++++++++++++---------------------- 12 files changed, 145 insertions(+), 261 deletions(-) diff --git a/server/handler.go b/server/handler.go index 34ca72882d..fe1b50d700 100644 --- a/server/handler.go +++ b/server/handler.go @@ -803,34 +803,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(2) - - // TODO: this should be merged below go func - var rowChan = make(chan sql.Row2, 512) - eg.Go(func() (err error) { - defer pan2err(&err) - defer wg.Done() - defer close(rowChan) - for { - select { - case <-ctx.Done(): - return context.Cause(ctx) - default: - row, err := iter.Next2(ctx) - if err == io.EOF { - return nil - } - if err != nil { - return err - } - select { - case rowChan <- row: - case <-ctx.Done(): - return nil - } - } - } - }) + wg.Add(1) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -864,18 +837,20 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - case row, ok := <-rowChan: - if !ok { + default: + row, err := iter.Next2(ctx) + if err == io.EOF { return nil } - // TODO: we can avoid deep copy here by redefining sql.Row2 - ctx.GetLogger().Tracef("spooling result row %s", row) - outRow := make([]sqltypes.Value, len(row)) - for i := range row { - outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + if err != nil { + return err } - res.Rows = append(res.Rows, outRow) + ctx.GetLogger().Tracef("spooling result row %s", row) + res.Rows = append(res.Rows, row) res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } } timer.Reset(waitTime) } diff --git a/sql/convert_value.go b/sql/convert_value.go index d46fe4de4e..d64d5fbe98 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -3,98 +3,46 @@ package sql import ( "fmt" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/go-mysql-server/sql/values" + + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. -func ConvertToValue(v interface{}) (Value, error) { +func ConvertToValue(v interface{}) (sqltypes.Value, error) { switch v := v.(type) { case nil: - return Value{ - Typ: query.Type_NULL_TYPE, - Val: nil, - }, nil + return sqltypes.MakeTrusted(query.Type_NULL_TYPE, nil), nil case int: - return Value{ - Typ: query.Type_INT64, - Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)), - }, nil + return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), int64(v))), nil case int8: - return Value{ - Typ: query.Type_INT8, - Val: values.WriteInt8(make([]byte, values.Int8Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT8, values.WriteInt8(make([]byte, values.Int8Size), v)), nil case int16: - return Value{ - Typ: query.Type_INT16, - Val: values.WriteInt16(make([]byte, values.Int16Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT16, values.WriteInt16(make([]byte, values.Int16Size), v)), nil case int32: - return Value{ - Typ: query.Type_INT32, - Val: values.WriteInt32(make([]byte, values.Int32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT32, values.WriteInt32(make([]byte, values.Int32Size), v)), nil case int64: - return Value{ - Typ: query.Type_INT64, - Val: values.WriteInt64(make([]byte, values.Int64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), v)), nil case uint: - return Value{ - Typ: query.Type_UINT64, - Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), uint64(v))), nil case uint8: - return Value{ - Typ: query.Type_UINT8, - Val: values.WriteUint8(make([]byte, values.Uint8Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT8, values.WriteUint8(make([]byte, values.Uint8Size), v)), nil case uint16: - return Value{ - Typ: query.Type_UINT16, - Val: values.WriteUint16(make([]byte, values.Uint16Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT16, values.WriteUint16(make([]byte, values.Uint16Size), v)), nil case uint32: - return Value{ - Typ: query.Type_UINT32, - Val: values.WriteUint32(make([]byte, values.Uint32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT32, values.WriteUint32(make([]byte, values.Uint32Size), v)), nil case uint64: - return Value{ - Typ: query.Type_UINT64, - Val: values.WriteUint64(make([]byte, values.Uint64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), v)), nil case float32: - return Value{ - Typ: query.Type_FLOAT32, - Val: values.WriteFloat32(make([]byte, values.Float32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_FLOAT32, values.WriteFloat32(make([]byte, values.Float32Size), v)), nil case float64: - return Value{ - Typ: query.Type_FLOAT64, - Val: values.WriteFloat64(make([]byte, values.Float64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_FLOAT64, values.WriteFloat64(make([]byte, values.Float64Size), v)), nil case string: - return Value{ - Typ: query.Type_VARCHAR, - Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + return sqltypes.MakeTrusted(query.Type_VARCHAR, values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation)), nil case []byte: - return Value{ - Typ: query.Type_BLOB, - Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + return sqltypes.MakeTrusted(query.Type_BLOB, values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation)), nil default: - return Value{}, fmt.Errorf("type %T not implemented", v) - } -} - -func MustConvertToValue(v interface{}) Value { - ret, err := ConvertToValue(v) - if err != nil { - panic(err) + return sqltypes.Value{}, fmt.Errorf("type %T not implemented", v) } - return ret } diff --git a/sql/core.go b/sql/core.go index c2996039eb..19a7a7a895 100644 --- a/sql/core.go +++ b/sql/core.go @@ -30,6 +30,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql/values" + "github.com/dolthub/vitess/go/sqltypes" ) // Expression is a combination of one or more SQL expressions. @@ -464,7 +465,7 @@ func DebugString(nodeOrExpression interface{}) string { type Expression2 interface { Expression // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (Value, error) + Eval2(ctx *Context, row Row2) (sqltypes.Value, error) // Type2 returns the expression type. Type2() Type2 IsExpr2() bool diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index c93c403f63..cc788e2cc6 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -15,8 +15,8 @@ package expression import ( - "bytes" "fmt" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -497,6 +497,7 @@ type GreaterThan struct { } var _ sql.Expression = (*GreaterThan)(nil) +var _ sql.Expression2 = (*GreaterThan)(nil) var _ sql.CollationCoercible = (*GreaterThan)(nil) // NewGreaterThan creates a new GreaterThan expression. @@ -523,7 +524,7 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { l, ok := gt.Left().(sql.Expression2) if !ok { panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) @@ -535,23 +536,28 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) lv, err := l.Eval2(ctx, row) if err != nil { - return sql.Value{}, nil + return sqltypes.Value{}, err } rv, err := r.Eval2(ctx, row) if err != nil { - return sql.Value{}, nil + return sqltypes.Value{}, err } - // TODO: better implementation - res := bytes.Compare(lv.Val, rv.Val) // TODO: this is probably wrong + // TODO: just assume they are int64 + l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) + if err != nil { + return sqltypes.Value{}, err + } + r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) + if err != nil { + return sqltypes.Value{}, err + } var rb byte - if res == 1 { + if l64 > r64 { rb = 1 } - ret := sql.Value{ - Val: sql.ValueBytes{rb}, - Typ: querypb.Type_INT8, - } + + ret := sqltypes.MakeTrusted(querypb.Type_INT8, []byte{rb}) return ret, nil } diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 5e9263760f..398ca7107a 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" errors "gopkg.in/src-d/go-errors.v1" @@ -149,12 +150,11 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) + return sqltypes.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) } - - return row.GetField(p.fieldIndex), nil + return row[p.fieldIndex], nil } func (p *GetField) IsExpr2() bool { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 2b5583dc5b..b386b86412 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" "github.com/dolthub/vitess/go/vt/proto/query" @@ -30,7 +31,7 @@ import ( type Literal struct { Val interface{} Typ sql.Type - val2 sql.Value + val2 sqltypes.Value } var _ sql.Expression = &Literal{} @@ -136,7 +137,7 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { return lit.val2, nil } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index c421699722..0173651464 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" "gopkg.in/src-d/go-errors.v1" @@ -71,7 +72,7 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { panic("unresolved column is a placeholder node, but Eval2 was called") } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index c2bf35c50e..0d0e4ebe39 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -109,6 +109,9 @@ type FilterIter struct { childIter2 sql.RowIter2 } +var _ sql.RowIter = (*FilterIter)(nil) +var _ sql.RowIter2 = (*FilterIter)(nil) + // NewFilterIter creates a new FilterIter. func NewFilterIter( cond sql.Expression, @@ -146,7 +149,7 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { if err != nil { return nil, err } - if res.Val[0] == 1 { + if res.Raw()[0] == 1 { return row, nil } } diff --git a/sql/row_frame.go b/sql/row_frame.go index ef3ea6010f..a4384ec458 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -17,6 +17,7 @@ package sql import ( "sync" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -26,10 +27,10 @@ const ( ) // Row2 is a slice of values -type Row2 []Value +type Row2 []sqltypes.Value // GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) Value { +func (r Row2) GetField(i int) sqltypes.Value { return r[i] } @@ -97,10 +98,7 @@ func (f *RowFrame) Row2() Row2 { rs := make(Row2, len(f.Values)) for i := range f.Values { - rs[i] = Value{ - Typ: f.Types[i], - Val: f.Values[i], - } + rs[i] = sqltypes.MakeTrusted(f.Types[i], f.Values[i]) } return rs } @@ -113,10 +111,7 @@ func (f *RowFrame) Row2Copy() Row2 { for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) - rs[i] = Value{ - Typ: f.Types[i], - Val: v, - } + rs[i] = sqltypes.MakeTrusted(f.Types[i], v) } return rs } diff --git a/sql/rows.go b/sql/rows.go index 191147ad68..2a969363bc 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -123,29 +123,29 @@ func RowFromRow2(sch Schema, r Row2) Row { for i, col := range sch { switch col.Type.Type() { case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Val) + row[i] = values.ReadInt8(r.GetField(i).Raw()) case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Val) + row[i] = values.ReadUint8(r.GetField(i).Raw()) case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Val) + row[i] = values.ReadInt16(r.GetField(i).Raw()) case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Val) + row[i] = values.ReadUint16(r.GetField(i).Raw()) case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Val) + row[i] = values.ReadInt32(r.GetField(i).Raw()) case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Val) + row[i] = values.ReadUint32(r.GetField(i).Raw()) case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Val) + row[i] = values.ReadInt64(r.GetField(i).Raw()) case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Val) + row[i] = values.ReadUint64(r.GetField(i).Raw()) case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Val) + row[i] = values.ReadFloat32(r.GetField(i).Raw()) case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Val) + row[i] = values.ReadFloat64(r.GetField(i).Raw()) case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) + row[i] = values.ReadString(r.GetField(i).Raw(), values.ByteOrderCollation) case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) + row[i] = values.ReadBytes(r.GetField(i).Raw(), values.ByteOrderCollation) case query.Type_BIT: fallthrough case query.Type_ENUM: diff --git a/sql/type.go b/sql/type.go index e4d0f8ff96..285744a564 100644 --- a/sql/type.go +++ b/sql/type.go @@ -295,13 +295,11 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type // Compare2 returns an integer comparing two Values. - Compare2(Value, Value) (int, error) + Compare2(sqltypes.Value, sqltypes.Value) (int, error) // Convert2 converts a value of a compatible type. - Convert2(Value) (Value, error) + Convert2(sqltypes.Value) (sqltypes.Value, error) // Zero2 returns the zero Value for this type. - Zero2() Value - // SQL2 returns the sqltypes.Value for the given value - SQL2(Value) (sqltypes.Value, error) + Zero2() sqltypes.Value } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/number.go b/sql/types/number.go index e9ecfc04f7..1cdb3f0b4a 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -728,7 +728,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { +func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, error) { switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: ca, err := convertValueToUint64(t, a) @@ -765,11 +765,11 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } return +1, nil default: - ca, err := convertValueToInt64(t, a) + ca, err := ConvertValueToInt64(t, a) if err != nil { return 0, err } - cb, err := convertValueToInt64(t, b) + cb, err := ConvertValueToInt64(t, b) if err != nil { return 0, err } @@ -784,84 +784,40 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } } -func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { +func (t NumberTypeImpl_) Convert2(value sqltypes.Value) (sqltypes.Value, error) { panic("implement me") } -func (t NumberTypeImpl_) Zero2() sql.Value { +func (t NumberTypeImpl_) Zero2() sqltypes.Value { switch t.baseType { case sqltypes.Int8: - x := values.WriteInt8(make([]byte, values.Int8Size), 0) - return sql.Value{ - Typ: query.Type_INT8, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT8, make([]byte, values.Int8Size)) case sqltypes.Int16: - x := values.WriteInt16(make([]byte, values.Int16Size), 0) - return sql.Value{ - Typ: query.Type_INT16, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT16, make([]byte, values.Int16Size)) case sqltypes.Int24: - x := values.WriteInt24(make([]byte, values.Int24Size), 0) - return sql.Value{ - Typ: query.Type_INT24, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT24, make([]byte, values.Int24Size)) case sqltypes.Int32: - x := values.WriteInt32(make([]byte, values.Int32Size), 0) - return sql.Value{ - Typ: query.Type_INT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT32, make([]byte, values.Int32Size)) case sqltypes.Int64: - x := values.WriteInt64(make([]byte, values.Int64Size), 0) - return sql.Value{ - Typ: query.Type_INT64, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT64, make([]byte, values.Int64Size)) case sqltypes.Uint8: - x := values.WriteUint8(make([]byte, values.Uint8Size), 0) - return sql.Value{ - Typ: query.Type_UINT8, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT8, make([]byte, values.Uint8Size)) case sqltypes.Uint16: - x := values.WriteUint16(make([]byte, values.Uint16Size), 0) - return sql.Value{ - Typ: query.Type_UINT16, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT16, make([]byte, values.Uint16Size)) case sqltypes.Uint24: - x := values.WriteUint24(make([]byte, values.Uint24Size), 0) - return sql.Value{ - Typ: query.Type_UINT24, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT24, make([]byte, values.Uint24Size)) case sqltypes.Uint32: - x := values.WriteUint32(make([]byte, values.Uint32Size), 0) - return sql.Value{ - Typ: query.Type_UINT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT32, make([]byte, values.Uint32Size)) case sqltypes.Uint64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT64, make([]byte, values.Uint64Size)) case sqltypes.Float32: + // TODO: 0 float32 is just 0? x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sql.Value{ - Typ: query.Type_FLOAT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_FLOAT32, x) case sqltypes.Float64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } + // TODO: 0 float64 is just 0? + x := values.WriteFloat64(make([]byte, values.Float64Size), 0) + return sqltypes.MakeTrusted(query.Type_FLOAT64, x) default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } @@ -1152,34 +1108,34 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { - switch v.Typ { +func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { + switch v.Type() { case query.Type_INT8: - return int64(values.ReadInt8(v.Val)), nil + return int64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return int64(values.ReadInt16(v.Val)), nil + return int64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return int64(values.ReadInt24(v.Val)), nil + return int64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return int64(values.ReadInt32(v.Val)), nil + return int64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return values.ReadInt64(v.Val), nil + return values.ReadInt64(v.Raw()), nil case query.Type_UINT8: - return int64(values.ReadUint8(v.Val)), nil + return int64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return int64(values.ReadUint16(v.Val)), nil + return int64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return int64(values.ReadUint24(v.Val)), nil + return int64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return int64(values.ReadUint32(v.Val)), nil + return int64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - v := values.ReadUint64(v.Val) + v := values.ReadUint64(v.Raw()) if v > math.MaxInt64 { return math.MaxInt64, nil } return int64(v), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) + v := values.ReadFloat32(v.Raw()) if v > float32(math.MaxInt64) { return math.MaxInt64, nil } else if v < float32(math.MinInt64) { @@ -1187,7 +1143,7 @@ func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { } return int64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) + v := values.ReadFloat64(v.Raw()) if v > float64(math.MaxInt64) { return math.MaxInt64, nil } else if v < float64(math.MinInt64) { @@ -1200,36 +1156,36 @@ func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { } } -func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { - switch v.Typ { +func convertValueToUint64(t NumberTypeImpl_, v sqltypes.Value) (uint64, error) { + switch v.Type() { case query.Type_INT8: - return uint64(values.ReadInt8(v.Val)), nil + return uint64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return uint64(values.ReadInt16(v.Val)), nil + return uint64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return uint64(values.ReadInt24(v.Val)), nil + return uint64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return uint64(values.ReadInt32(v.Val)), nil + return uint64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return uint64(values.ReadInt64(v.Val)), nil + return uint64(values.ReadInt64(v.Raw())), nil case query.Type_UINT8: - return uint64(values.ReadUint8(v.Val)), nil + return uint64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return uint64(values.ReadUint16(v.Val)), nil + return uint64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return uint64(values.ReadUint24(v.Val)), nil + return uint64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return uint64(values.ReadUint32(v.Val)), nil + return uint64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - return values.ReadUint64(v.Val), nil + return values.ReadUint64(v.Raw()), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) + v := values.ReadFloat32(v.Raw()) if v >= float32(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) + v := values.ReadFloat64(v.Raw()) if v >= float64(math.MaxUint64) { return math.MaxUint64, nil } @@ -1428,32 +1384,32 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { - switch v.Typ { +func convertValueToFloat64(t NumberTypeImpl_, v sqltypes.Value) (float64, error) { + switch v.Type() { case query.Type_INT8: - return float64(values.ReadInt8(v.Val)), nil + return float64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return float64(values.ReadInt16(v.Val)), nil + return float64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return float64(values.ReadInt24(v.Val)), nil + return float64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return float64(values.ReadInt32(v.Val)), nil + return float64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return float64(values.ReadInt64(v.Val)), nil + return float64(values.ReadInt64(v.Raw())), nil case query.Type_UINT8: - return float64(values.ReadUint8(v.Val)), nil + return float64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return float64(values.ReadUint16(v.Val)), nil + return float64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return float64(values.ReadUint24(v.Val)), nil + return float64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return float64(values.ReadUint32(v.Val)), nil + return float64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - return float64(values.ReadUint64(v.Val)), nil + return float64(values.ReadUint64(v.Raw())), nil case query.Type_FLOAT32: - return float64(values.ReadFloat32(v.Val)), nil + return float64(values.ReadFloat32(v.Raw())), nil case query.Type_FLOAT64: - return values.ReadFloat64(v.Val), nil + return values.ReadFloat64(v.Raw()), nil default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } From 310942f51cff01840b554b90654325a88a252198 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 13 Oct 2025 18:54:34 +0000 Subject: [PATCH 11/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/core.go | 2 +- sql/expression/comparison.go | 2 +- sql/expression/get_field.go | 2 +- sql/expression/literal.go | 2 +- sql/expression/unresolved.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core.go b/sql/core.go index 19a7a7a895..ee8d6e2f4d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -26,11 +26,11 @@ import ( "time" "unsafe" + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql/values" - "github.com/dolthub/vitess/go/sqltypes" ) // Expression is a combination of one or more SQL expressions. diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index cc788e2cc6..6d7a67d790 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,8 +16,8 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 398ca7107a..2611858867 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" diff --git a/sql/expression/literal.go b/sql/expression/literal.go index b386b86412..1411ca0830 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 0173651464..a0df5ab12f 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" From cd838a91927ad00066e01d7aa0a697d781534864 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 15:25:08 -0700 Subject: [PATCH 12/59] resplit --- server/handler.go | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/server/handler.go b/server/handler.go index fe1b50d700..09399c09cf 100644 --- a/server/handler.go +++ b/server/handler.go @@ -803,7 +803,34 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) + + // Read rows from iter and send them off + var rowChan = make(chan sql.Row2, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -831,20 +858,15 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq case <-ctx.Done(): return context.Cause(ctx) case <-timer.C: - // TODO: timer should probably go in its own thread, as rowChan is blocking if h.readTimeout != 0 { // Cancel and return so Vitess can call the CloseConnection callback ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - default: - row, err := iter.Next2(ctx) - if err == io.EOF { + case row, ok := <-rowChan: + if !ok { return nil } - if err != nil { - return err - } ctx.GetLogger().Tracef("spooling result row %s", row) res.Rows = append(res.Rows, row) res.RowsAffected++ From 0c3be294959ad646dfbcfbb4671f60ed194dda55 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 14 Oct 2025 02:07:50 -0700 Subject: [PATCH 13/59] TODO --- server/handler.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/handler.go b/server/handler.go index 09399c09cf..1000b98ae6 100644 --- a/server/handler.go +++ b/server/handler.go @@ -805,6 +805,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq wg := sync.WaitGroup{} wg.Add(2) + // TODO: send results instead of rows? // Read rows from iter and send them off var rowChan = make(chan sql.Row2, 512) eg.Go(func() (err error) { From f59647cb980540062f80ce9bdf6319dc05f8405d Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 14 Oct 2025 16:03:58 -0700 Subject: [PATCH 14/59] small fixes --- sql/expression/literal.go | 4 ++-- sql/values/encoding.go | 28 +--------------------------- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 1411ca0830..977a9f5506 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -154,8 +154,8 @@ func (lit *Literal) Type2() sql.Type2 { } // Value returns the literal value. -func (p *Literal) Value() interface{} { - return p.Val +func (lit *Literal) Value() interface{} { + return lit.Val } func (lit *Literal) WithResolvedChildren(children []any) (any, error) { diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 3472e870e5..d00e630091 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -130,11 +130,7 @@ func ReadUint16(val []byte) uint16 { func ReadInt24(val []byte) (i int32) { expectSize(val, Int24Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int32(binary.LittleEndian.Uint32(tmp[:])) + i = int32(binary.LittleEndian.Uint32([]byte{0, val[0], val[1], val[2]})) return } @@ -158,28 +154,6 @@ func ReadUint32(val []byte) uint32 { return binary.LittleEndian.Uint32(val) } -func ReadInt48(val []byte) (i int64) { - expectSize(val, Int48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int64(binary.LittleEndian.Uint64(tmp[:])) - return -} - -func ReadUint48(val []byte) (u uint64) { - expectSize(val, Uint48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint64(tmp[:]) - return -} - func ReadInt64(val []byte) int64 { expectSize(val, Int64Size) return int64(binary.LittleEndian.Uint64(val)) From 83c6d0f51f8d18b8e511d5022d889c7921760163 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 15 Oct 2025 15:04:32 -0700 Subject: [PATCH 15/59] don't use vitess type --- server/handler.go | 8 +- sql/convert_value.go | 80 +++++++++++++---- sql/core.go | 3 +- sql/expression/comparison.go | 16 ++-- sql/expression/get_field.go | 5 +- sql/expression/literal.go | 5 +- sql/expression/unresolved.go | 3 +- sql/plan/filter.go | 2 +- sql/row_frame.go | 15 ++-- sql/rows.go | 24 ++--- sql/type.go | 6 +- sql/types/number.go | 164 ++++++++++++++++++++++------------- 12 files changed, 213 insertions(+), 118 deletions(-) diff --git a/server/handler.go b/server/handler.go index 1000b98ae6..1ac48b7c7f 100644 --- a/server/handler.go +++ b/server/handler.go @@ -868,8 +868,12 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq if !ok { return nil } - ctx.GetLogger().Tracef("spooling result row %s", row) - res.Rows = append(res.Rows, row) + resRow := make([]sqltypes.Value, len(row)) + for i, v := range row { + resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) + } + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) res.RowsAffected++ if !timer.Stop() { <-timer.C diff --git a/sql/convert_value.go b/sql/convert_value.go index d64d5fbe98..880b9f2f58 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -5,44 +5,88 @@ import ( "github.com/dolthub/go-mysql-server/sql/values" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. -func ConvertToValue(v interface{}) (sqltypes.Value, error) { +func ConvertToValue(v interface{}) (Value, error) { switch v := v.(type) { case nil: - return sqltypes.MakeTrusted(query.Type_NULL_TYPE, nil), nil + return Value{ + Typ: query.Type_NULL_TYPE, + Val: nil, + }, nil case int: - return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), int64(v))), nil + return Value{ + Typ: query.Type_INT64, + Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)), + }, nil case int8: - return sqltypes.MakeTrusted(query.Type_INT8, values.WriteInt8(make([]byte, values.Int8Size), v)), nil + return Value{ + Typ: query.Type_INT8, + Val: values.WriteInt8(make([]byte, values.Int8Size), v), + }, nil case int16: - return sqltypes.MakeTrusted(query.Type_INT16, values.WriteInt16(make([]byte, values.Int16Size), v)), nil + return Value{ + Typ: query.Type_INT16, + Val: values.WriteInt16(make([]byte, values.Int16Size), v), + }, nil case int32: - return sqltypes.MakeTrusted(query.Type_INT32, values.WriteInt32(make([]byte, values.Int32Size), v)), nil + return Value{ + Typ: query.Type_INT32, + Val: values.WriteInt32(make([]byte, values.Int32Size), v), + }, nil case int64: - return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), v)), nil + return Value{ + Typ: query.Type_INT64, + Val: values.WriteInt64(make([]byte, values.Int64Size), v), + }, nil case uint: - return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), uint64(v))), nil + return Value{ + Typ: query.Type_UINT64, + Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)), + }, nil case uint8: - return sqltypes.MakeTrusted(query.Type_UINT8, values.WriteUint8(make([]byte, values.Uint8Size), v)), nil + return Value{ + Typ: query.Type_UINT8, + Val: values.WriteUint8(make([]byte, values.Uint8Size), v), + }, nil case uint16: - return sqltypes.MakeTrusted(query.Type_UINT16, values.WriteUint16(make([]byte, values.Uint16Size), v)), nil + return Value{ + Typ: query.Type_UINT16, + Val: values.WriteUint16(make([]byte, values.Uint16Size), v), + }, nil case uint32: - return sqltypes.MakeTrusted(query.Type_UINT32, values.WriteUint32(make([]byte, values.Uint32Size), v)), nil + return Value{ + Typ: query.Type_UINT32, + Val: values.WriteUint32(make([]byte, values.Uint32Size), v), + }, nil case uint64: - return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), v)), nil + return Value{ + Typ: query.Type_UINT64, + Val: values.WriteUint64(make([]byte, values.Uint64Size), v), + }, nil case float32: - return sqltypes.MakeTrusted(query.Type_FLOAT32, values.WriteFloat32(make([]byte, values.Float32Size), v)), nil + return Value{ + Typ: query.Type_FLOAT32, + Val: values.WriteFloat32(make([]byte, values.Float32Size), v), + }, nil case float64: - return sqltypes.MakeTrusted(query.Type_FLOAT64, values.WriteFloat64(make([]byte, values.Float64Size), v)), nil + return Value{ + Typ: query.Type_FLOAT64, + Val: values.WriteFloat64(make([]byte, values.Float64Size), v), + }, nil case string: - return sqltypes.MakeTrusted(query.Type_VARCHAR, values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation)), nil + return Value{ + Typ: query.Type_VARCHAR, + Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation), + }, nil case []byte: - return sqltypes.MakeTrusted(query.Type_BLOB, values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation)), nil + return Value{ + Typ: query.Type_BLOB, + Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation), + }, nil default: - return sqltypes.Value{}, fmt.Errorf("type %T not implemented", v) + return Value{}, fmt.Errorf("type %T not implemented", v) } } diff --git a/sql/core.go b/sql/core.go index ee8d6e2f4d..c2996039eb 100644 --- a/sql/core.go +++ b/sql/core.go @@ -26,7 +26,6 @@ import ( "time" "unsafe" - "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" @@ -465,7 +464,7 @@ func DebugString(nodeOrExpression interface{}) string { type Expression2 interface { Expression // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (sqltypes.Value, error) + Eval2(ctx *Context, row Row2) (Value, error) // Type2 returns the expression type. Type2() Type2 IsExpr2() bool diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 6d7a67d790..5257dc4362 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,7 +17,6 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -524,7 +523,7 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { l, ok := gt.Left().(sql.Expression2) if !ok { panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) @@ -536,28 +535,31 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, er lv, err := l.Eval2(ctx, row) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } rv, err := r.Eval2(ctx, row) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } // TODO: just assume they are int64 l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } var rb byte if l64 > r64 { rb = 1 } - ret := sqltypes.MakeTrusted(querypb.Type_INT8, []byte{rb}) + ret := sql.Value{ + Val: []byte{rb}, + Typ: querypb.Type_INT8, + } return ret, nil } diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 2611858867..319406e073 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -150,9 +149,9 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sqltypes.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) + return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) } return row[p.fieldIndex], nil } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 977a9f5506..cc74bd7dc6 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" @@ -31,7 +30,7 @@ import ( type Literal struct { Val interface{} Typ sql.Type - val2 sqltypes.Value + val2 sql.Value } var _ sql.Expression = &Literal{} @@ -137,7 +136,7 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return lit.val2, nil } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index a0df5ab12f..c421699722 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -72,7 +71,7 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { panic("unresolved column is a placeholder node, but Eval2 was called") } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 0d0e4ebe39..79fa7d14e5 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -149,7 +149,7 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { if err != nil { return nil, err } - if res.Raw()[0] == 1 { + if res.Val[0] == 1 { return row, nil } } diff --git a/sql/row_frame.go b/sql/row_frame.go index a4384ec458..ebb79682e4 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -17,7 +17,6 @@ package sql import ( "sync" - "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -27,10 +26,10 @@ const ( ) // Row2 is a slice of values -type Row2 []sqltypes.Value +type Row2 []Value // GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) sqltypes.Value { +func (r Row2) GetField(i int) Value { return r[i] } @@ -98,7 +97,10 @@ func (f *RowFrame) Row2() Row2 { rs := make(Row2, len(f.Values)) for i := range f.Values { - rs[i] = sqltypes.MakeTrusted(f.Types[i], f.Values[i]) + rs[i] = Value{ + Val: f.Values[i], + Typ: f.Types[i], + } } return rs } @@ -111,7 +113,10 @@ func (f *RowFrame) Row2Copy() Row2 { for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) - rs[i] = sqltypes.MakeTrusted(f.Types[i], v) + rs[i] = Value{ + Val: v, + Typ: f.Types[i], + } } return rs } diff --git a/sql/rows.go b/sql/rows.go index 2a969363bc..191147ad68 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -123,29 +123,29 @@ func RowFromRow2(sch Schema, r Row2) Row { for i, col := range sch { switch col.Type.Type() { case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Raw()) + row[i] = values.ReadInt8(r.GetField(i).Val) case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Raw()) + row[i] = values.ReadUint8(r.GetField(i).Val) case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Raw()) + row[i] = values.ReadInt16(r.GetField(i).Val) case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Raw()) + row[i] = values.ReadUint16(r.GetField(i).Val) case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Raw()) + row[i] = values.ReadInt32(r.GetField(i).Val) case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Raw()) + row[i] = values.ReadUint32(r.GetField(i).Val) case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Raw()) + row[i] = values.ReadInt64(r.GetField(i).Val) case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Raw()) + row[i] = values.ReadUint64(r.GetField(i).Val) case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Raw()) + row[i] = values.ReadFloat32(r.GetField(i).Val) case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Raw()) + row[i] = values.ReadFloat64(r.GetField(i).Val) case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Raw(), values.ByteOrderCollation) + row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Raw(), values.ByteOrderCollation) + row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) case query.Type_BIT: fallthrough case query.Type_ENUM: diff --git a/sql/type.go b/sql/type.go index 285744a564..6d9f9adb01 100644 --- a/sql/type.go +++ b/sql/type.go @@ -295,11 +295,11 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type // Compare2 returns an integer comparing two Values. - Compare2(sqltypes.Value, sqltypes.Value) (int, error) + Compare2(Value, Value) (int, error) // Convert2 converts a value of a compatible type. - Convert2(sqltypes.Value) (sqltypes.Value, error) + Convert2(Value) (Value, error) // Zero2 returns the zero Value for this type. - Zero2() sqltypes.Value + Zero2() Value } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/number.go b/sql/types/number.go index 1cdb3f0b4a..824ab2b33c 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -728,7 +728,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, error) { +func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: ca, err := convertValueToUint64(t, a) @@ -784,40 +784,84 @@ func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, erro } } -func (t NumberTypeImpl_) Convert2(value sqltypes.Value) (sqltypes.Value, error) { +func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { panic("implement me") } -func (t NumberTypeImpl_) Zero2() sqltypes.Value { +func (t NumberTypeImpl_) Zero2() sql.Value { switch t.baseType { case sqltypes.Int8: - return sqltypes.MakeTrusted(query.Type_INT8, make([]byte, values.Int8Size)) + x := values.WriteInt8(make([]byte, values.Int8Size), 0) + return sql.Value{ + Typ: query.Type_INT8, + Val: x, + } case sqltypes.Int16: - return sqltypes.MakeTrusted(query.Type_INT16, make([]byte, values.Int16Size)) + x := values.WriteInt16(make([]byte, values.Int16Size), 0) + return sql.Value{ + Typ: query.Type_INT16, + Val: x, + } case sqltypes.Int24: - return sqltypes.MakeTrusted(query.Type_INT24, make([]byte, values.Int24Size)) + x := values.WriteInt24(make([]byte, values.Int24Size), 0) + return sql.Value{ + Typ: query.Type_INT24, + Val: x, + } case sqltypes.Int32: - return sqltypes.MakeTrusted(query.Type_INT32, make([]byte, values.Int32Size)) + x := values.WriteInt32(make([]byte, values.Int32Size), 0) + return sql.Value{ + Typ: query.Type_INT32, + Val: x, + } case sqltypes.Int64: - return sqltypes.MakeTrusted(query.Type_INT64, make([]byte, values.Int64Size)) + x := values.WriteInt64(make([]byte, values.Int64Size), 0) + return sql.Value{ + Typ: query.Type_INT64, + Val: x, + } case sqltypes.Uint8: - return sqltypes.MakeTrusted(query.Type_UINT8, make([]byte, values.Uint8Size)) + x := values.WriteUint8(make([]byte, values.Uint8Size), 0) + return sql.Value{ + Typ: query.Type_UINT8, + Val: x, + } case sqltypes.Uint16: - return sqltypes.MakeTrusted(query.Type_UINT16, make([]byte, values.Uint16Size)) + x := values.WriteUint16(make([]byte, values.Uint16Size), 0) + return sql.Value{ + Typ: query.Type_UINT16, + Val: x, + } case sqltypes.Uint24: - return sqltypes.MakeTrusted(query.Type_UINT24, make([]byte, values.Uint24Size)) + x := values.WriteUint24(make([]byte, values.Uint24Size), 0) + return sql.Value{ + Typ: query.Type_UINT24, + Val: x, + } case sqltypes.Uint32: - return sqltypes.MakeTrusted(query.Type_UINT32, make([]byte, values.Uint32Size)) + x := values.WriteUint32(make([]byte, values.Uint32Size), 0) + return sql.Value{ + Typ: query.Type_UINT32, + Val: x, + } case sqltypes.Uint64: - return sqltypes.MakeTrusted(query.Type_UINT64, make([]byte, values.Uint64Size)) + x := values.WriteUint64(make([]byte, values.Uint64Size), 0) + return sql.Value{ + Typ: query.Type_UINT64, + Val: x, + } case sqltypes.Float32: - // TODO: 0 float32 is just 0? x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sqltypes.MakeTrusted(query.Type_FLOAT32, x) + return sql.Value{ + Typ: query.Type_FLOAT32, + Val: x, + } case sqltypes.Float64: - // TODO: 0 float64 is just 0? - x := values.WriteFloat64(make([]byte, values.Float64Size), 0) - return sqltypes.MakeTrusted(query.Type_FLOAT64, x) + x := values.WriteUint64(make([]byte, values.Uint64Size), 0) + return sql.Value{ + Typ: query.Type_UINT64, + Val: x, + } default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } @@ -1108,34 +1152,34 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { - switch v.Type() { +func ConvertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { + switch v.Typ { case query.Type_INT8: - return int64(values.ReadInt8(v.Raw())), nil + return int64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return int64(values.ReadInt16(v.Raw())), nil + return int64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return int64(values.ReadInt24(v.Raw())), nil + return int64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return int64(values.ReadInt32(v.Raw())), nil + return int64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return values.ReadInt64(v.Raw()), nil + return values.ReadInt64(v.Val), nil case query.Type_UINT8: - return int64(values.ReadUint8(v.Raw())), nil + return int64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return int64(values.ReadUint16(v.Raw())), nil + return int64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return int64(values.ReadUint24(v.Raw())), nil + return int64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return int64(values.ReadUint32(v.Raw())), nil + return int64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - v := values.ReadUint64(v.Raw()) + v := values.ReadUint64(v.Val) if v > math.MaxInt64 { return math.MaxInt64, nil } return int64(v), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Raw()) + v := values.ReadFloat32(v.Val) if v > float32(math.MaxInt64) { return math.MaxInt64, nil } else if v < float32(math.MinInt64) { @@ -1143,7 +1187,7 @@ func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { } return int64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Raw()) + v := values.ReadFloat64(v.Val) if v > float64(math.MaxInt64) { return math.MaxInt64, nil } else if v < float64(math.MinInt64) { @@ -1156,36 +1200,36 @@ func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { } } -func convertValueToUint64(t NumberTypeImpl_, v sqltypes.Value) (uint64, error) { - switch v.Type() { +func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { + switch v.Typ { case query.Type_INT8: - return uint64(values.ReadInt8(v.Raw())), nil + return uint64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return uint64(values.ReadInt16(v.Raw())), nil + return uint64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return uint64(values.ReadInt24(v.Raw())), nil + return uint64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return uint64(values.ReadInt32(v.Raw())), nil + return uint64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return uint64(values.ReadInt64(v.Raw())), nil + return uint64(values.ReadInt64(v.Val)), nil case query.Type_UINT8: - return uint64(values.ReadUint8(v.Raw())), nil + return uint64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return uint64(values.ReadUint16(v.Raw())), nil + return uint64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return uint64(values.ReadUint24(v.Raw())), nil + return uint64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return uint64(values.ReadUint32(v.Raw())), nil + return uint64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - return values.ReadUint64(v.Raw()), nil + return values.ReadUint64(v.Val), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Raw()) + v := values.ReadFloat32(v.Val) if v >= float32(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Raw()) + v := values.ReadFloat64(v.Val) if v >= float64(math.MaxUint64) { return math.MaxUint64, nil } @@ -1384,32 +1428,32 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sqltypes.Value) (float64, error) { - switch v.Type() { +func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { + switch v.Typ { case query.Type_INT8: - return float64(values.ReadInt8(v.Raw())), nil + return float64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return float64(values.ReadInt16(v.Raw())), nil + return float64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return float64(values.ReadInt24(v.Raw())), nil + return float64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return float64(values.ReadInt32(v.Raw())), nil + return float64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return float64(values.ReadInt64(v.Raw())), nil + return float64(values.ReadInt64(v.Val)), nil case query.Type_UINT8: - return float64(values.ReadUint8(v.Raw())), nil + return float64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return float64(values.ReadUint16(v.Raw())), nil + return float64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return float64(values.ReadUint24(v.Raw())), nil + return float64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return float64(values.ReadUint32(v.Raw())), nil + return float64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - return float64(values.ReadUint64(v.Raw())), nil + return float64(values.ReadUint64(v.Val)), nil case query.Type_FLOAT32: - return float64(values.ReadFloat32(v.Raw())), nil + return float64(values.ReadFloat32(v.Val)), nil case query.Type_FLOAT64: - return values.ReadFloat64(v.Raw()), nil + return values.ReadFloat64(v.Val), nil default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } From baf2d775420e7a34427c2ea9483312371ff80bfa Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 16 Oct 2025 17:00:27 -0700 Subject: [PATCH 16/59] include Wrapper values for out of band values --- server/handler.go | 17 ++++++++++++++++- sql/row_frame.go | 9 +++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/server/handler.go b/server/handler.go index 1ac48b7c7f..28bd778516 100644 --- a/server/handler.go +++ b/server/handler.go @@ -870,7 +870,22 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq } resRow := make([]sqltypes.Value, len(row)) for i, v := range row { - resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) + if v.Val != nil || v.Val2 == nil { + resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) + continue + } + dVal, err := v.Val2.UnwrapAny(ctx) + if err != nil { + return err + } + switch dVal := dVal.(type) { + case []byte: + resRow[i] = sqltypes.MakeTrusted(v.Typ, dVal) + case string: + resRow[i] = sqltypes.MakeTrusted(v.Typ, []byte(dVal)) + default: + panic(fmt.Sprintf("unexpected type %T", dVal)) + } } ctx.GetLogger().Tracef("spooling result row %s", resRow) res.Rows = append(res.Rows, resRow) diff --git a/sql/row_frame.go b/sql/row_frame.go index ebb79682e4..0e0ad7f23c 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -38,10 +38,13 @@ func (r Row2) Len() int { return len(r) } +type ValueBytes []byte + // Value is a logical index into a Row2. For efficiency reasons, use sparingly. type Value struct { - Val ValueBytes - Typ querypb.Type + Val ValueBytes + Val2 AnyWrapper + Typ querypb.Type // TODO: consider sqltypes.Type instead } // IsNull returns whether this value represents NULL @@ -49,8 +52,6 @@ func (v Value) IsNull() bool { return v.Val == nil || v.Typ == querypb.Type_NULL_TYPE } -type ValueBytes []byte - type RowFrame struct { Types []querypb.Type From f9dbd77f20dcfb4a7b24f67c314646e0a0ff0e89 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 17 Oct 2025 10:40:50 -0700 Subject: [PATCH 17/59] rename to rowvalue --- server/handler.go | 6 +-- sql/cache.go | 6 +-- sql/core.go | 2 +- sql/expression/comparison.go | 2 +- sql/expression/get_field.go | 6 +-- sql/expression/literal.go | 2 +- sql/expression/sort.go | 4 +- sql/expression/unresolved.go | 2 +- sql/memory.go | 4 +- sql/plan/filter.go | 2 +- sql/plan/indexed_table_access.go | 4 +- sql/plan/process.go | 2 +- sql/row_frame.go | 36 ++++++---------- sql/rowexec/transaction_iters.go | 2 +- sql/rows.go | 71 +------------------------------- sql/table_iter.go | 2 +- 16 files changed, 37 insertions(+), 116 deletions(-) diff --git a/server/handler.go b/server/handler.go index 28bd778516..1b71394a17 100644 --- a/server/handler.go +++ b/server/handler.go @@ -807,7 +807,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq // TODO: send results instead of rows? // Read rows from iter and send them off - var rowChan = make(chan sql.Row2, 512) + var rowChan = make(chan sql.ValueRow, 512) eg.Go(func() (err error) { defer pan2err(&err) defer wg.Done() @@ -870,11 +870,11 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq } resRow := make([]sqltypes.Value, len(row)) for i, v := range row { - if v.Val != nil || v.Val2 == nil { + if v.Val != nil || v.WrappedVal == nil { resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) continue } - dVal, err := v.Val2.UnwrapAny(ctx) + dVal, err := v.WrappedVal.UnwrapAny(ctx) if err != nil { return err } diff --git a/sql/cache.go b/sql/cache.go index d664b2c957..b2984eac24 100644 --- a/sql/cache.go +++ b/sql/cache.go @@ -74,7 +74,7 @@ type rowsCache struct { memory Freeable reporter Reporter rows []Row - rows2 []Row2 + rows2 []ValueRow } func newRowsCache(memory Freeable, r Reporter) *rowsCache { @@ -92,7 +92,7 @@ func (c *rowsCache) Add(row Row) error { func (c *rowsCache) Get() []Row { return c.rows } -func (c *rowsCache) Add2(row2 Row2) error { +func (c *rowsCache) Add2(row2 ValueRow) error { if !releaseMemoryIfNeeded(c.reporter, c.memory.Free) { return ErrNoMemoryAvailable.New() } @@ -101,7 +101,7 @@ func (c *rowsCache) Add2(row2 Row2) error { return nil } -func (c *rowsCache) Get2() []Row2 { +func (c *rowsCache) Get2() []ValueRow { return c.rows2 } diff --git a/sql/core.go b/sql/core.go index c2996039eb..37c32f1672 100644 --- a/sql/core.go +++ b/sql/core.go @@ -464,7 +464,7 @@ func DebugString(nodeOrExpression interface{}) string { type Expression2 interface { Expression // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (Value, error) + Eval2(ctx *Context, row ValueRow) (Value, error) // Type2 returns the expression type. Type2() Type2 IsExpr2() bool diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 5257dc4362..0e44926350 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -523,7 +523,7 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { l, ok := gt.Left().(sql.Expression2) if !ok { panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 319406e073..e660c4fcb8 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -149,9 +149,9 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { - if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) +func (p *GetField) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + if p.fieldIndex < 0 || p.fieldIndex >= len(row) { + return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, len(row)) } return row[p.fieldIndex], nil } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index cc74bd7dc6..4d3f9334c7 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -136,7 +136,7 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (lit *Literal) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { return lit.val2, nil } diff --git a/sql/expression/sort.go b/sql/expression/sort.go index d54d13ea77..164ef4f279 100644 --- a/sql/expression/sort.go +++ b/sql/expression/sort.go @@ -86,12 +86,12 @@ func (s *Sorter) Less(i, j int) bool { return false } -// Sorter2 is a version of Sorter that operates on Row2 +// Sorter2 is a version of Sorter that operates on ValueRow type Sorter2 struct { LastError error Ctx *sql.Context SortFields []sql.SortField - Rows []sql.Row2 + Rows []sql.ValueRow } func (s *Sorter2) Len() int { diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index c421699722..40e4799d75 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -71,7 +71,7 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { panic("unresolved column is a placeholder node, but Eval2 was called") } diff --git a/sql/memory.go b/sql/memory.go index 651ed9e68f..d3ad793942 100644 --- a/sql/memory.go +++ b/sql/memory.go @@ -70,9 +70,9 @@ type Rows2Cache interface { // Add2 a new row to the cache. If there is no memory available, it will try to // free some memory. If after that there is still no memory available, it // will return an error and erase all the content of the cache. - Add2(Row2) error + Add2(ValueRow) error // Get2 gets all rows. - Get2() []Row2 + Get2() []ValueRow } // ErrNoMemoryAvailable is returned when there is no more available memory. diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 79fa7d14e5..4dfff0e359 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -139,7 +139,7 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } -func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { +func (i *FilterIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { for { row, err := i.childIter2.Next2(ctx) if err != nil { diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index 7a708986b7..ee768c8a1a 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -307,7 +307,7 @@ func (i *IndexedTableAccess) GetLookup(ctx *sql.Context, row sql.Row) (sql.Index return i.lb.GetLookup(ctx, key) } -func (i *IndexedTableAccess) getLookup2(ctx *sql.Context, row sql.Row2) (sql.IndexLookup, error) { +func (i *IndexedTableAccess) getLookup2(ctx *sql.Context, row sql.ValueRow) (sql.IndexLookup, error) { // if the lookup was provided at analysis time (static evaluation), use it. if !i.lookup.IsEmpty() { return i.lookup, nil @@ -636,7 +636,7 @@ func (lb *LookupBuilder) GetKey(ctx *sql.Context, row sql.Row) (lookupBuilderKey return lb.key, nil } -func (lb *LookupBuilder) GetKey2(ctx *sql.Context, row sql.Row2) (lookupBuilderKey, error) { +func (lb *LookupBuilder) GetKey2(ctx *sql.Context, row sql.ValueRow) (lookupBuilderKey, error) { if lb.key == nil { lb.key = make([]interface{}, len(lb.keyExprs)) } diff --git a/sql/plan/process.go b/sql/plan/process.go index 70a687247f..e9974bbd72 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -318,7 +318,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } -func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { +func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { row, err := i.iter2.Next2(ctx) if err != nil { return nil, err diff --git a/sql/row_frame.go b/sql/row_frame.go index 0e0ad7f23c..1c659b59f4 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -25,31 +25,21 @@ const ( fieldArrSize = 2048 ) -// Row2 is a slice of values -type Row2 []Value - -// GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) Value { - return r[i] -} - -// Len returns the number of fields of this row -func (r Row2) Len() int { - return len(r) -} - type ValueBytes []byte -// Value is a logical index into a Row2. For efficiency reasons, use sparingly. +// Value is a logical index into a ValueRow. For efficiency reasons, use sparingly. type Value struct { - Val ValueBytes - Val2 AnyWrapper - Typ querypb.Type // TODO: consider sqltypes.Type instead + Val ValueBytes + WrappedVal AnyWrapper + Typ querypb.Type // TODO: consider sqltypes.Type instead } +// ValueRow is a slice of values +type ValueRow []Value + // IsNull returns whether this value represents NULL func (v Value) IsNull() bool { - return v.Val == nil || v.Typ == querypb.Type_NULL_TYPE + return (v.Val == nil && v.WrappedVal == nil) || v.Typ == querypb.Type_NULL_TYPE } type RowFrame struct { @@ -89,14 +79,14 @@ func (f *RowFrame) Recycle() { framePool.Put(f) } -// Row2 returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so +// ValueRow returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so // further modification to this frame may result in the returned value changing as well. -func (f *RowFrame) Row2() Row2 { +func (f *RowFrame) Row2() ValueRow { if f == nil { return nil } - rs := make(Row2, len(f.Values)) + rs := make(ValueRow, len(f.Values)) for i := range f.Values { rs[i] = Value{ Val: f.Values[i], @@ -108,8 +98,8 @@ func (f *RowFrame) Row2() Row2 { // Row2Copy returns the row in this frame as a deep copy of the underlying byte arrays. Useful when reusing the // RowFrame object via Clear() -func (f *RowFrame) Row2Copy() Row2 { - rs := make(Row2, len(f.Values)) +func (f *RowFrame) Row2Copy() ValueRow { + rs := make(ValueRow, len(f.Values)) // TODO: it would be faster here to just copy the entire value backing array in one pass for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index f0f56168ef..f2f466d281 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -100,7 +100,7 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } -func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { +func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { return t.childIter2.Next2(ctx) } diff --git a/sql/rows.go b/sql/rows.go index 191147ad68..9f8ef3d1c5 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -18,10 +18,6 @@ import ( "fmt" "io" "strings" - - "github.com/dolthub/vitess/go/vt/proto/query" - - "github.com/dolthub/go-mysql-server/sql/values" ) // Row is a tuple of values. @@ -94,7 +90,7 @@ type RowIter interface { type RowIter2 interface { RowIter - Next2(ctx *Context) (Row2, error) + Next2(ctx *Context) (ValueRow, error) IsRowIter2(ctx *Context) bool } @@ -118,71 +114,6 @@ func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { return rows, i.Close(ctx) } -func RowFromRow2(sch Schema, r Row2) Row { - row := make(Row, len(sch)) - for i, col := range sch { - switch col.Type.Type() { - case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Val) - case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Val) - case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Val) - case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Val) - case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Val) - case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Val) - case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Val) - case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Val) - case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Val) - case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Val) - case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) - case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) - case query.Type_BIT: - fallthrough - case query.Type_ENUM: - fallthrough - case query.Type_SET: - fallthrough - case query.Type_TUPLE: - fallthrough - case query.Type_GEOMETRY: - fallthrough - case query.Type_JSON: - fallthrough - case query.Type_EXPRESSION: - fallthrough - case query.Type_INT24: - fallthrough - case query.Type_UINT24: - fallthrough - case query.Type_TIMESTAMP: - fallthrough - case query.Type_DATE: - fallthrough - case query.Type_TIME: - fallthrough - case query.Type_DATETIME: - fallthrough - case query.Type_YEAR: - fallthrough - case query.Type_DECIMAL: - panic(fmt.Sprintf("Unimplemented type conversion: %T", col.Type)) - default: - panic(fmt.Sprintf("unknown type %T", col.Type)) - } - } - return row -} - // RowsToRowIter creates a RowIter that iterates over the given rows. func RowsToRowIter(rows ...Row) RowIter { return &sliceRowIter{rows: rows} diff --git a/sql/table_iter.go b/sql/table_iter.go index 884778307a..8f4b12fa30 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -79,7 +79,7 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } -func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { +func (i *TableRowIter) Next2(ctx *Context) (ValueRow, error) { select { case <-ctx.Done(): return nil, ctx.Err() From 8262f83413a78ac9bd486a7ee4870277f204ceef Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 20 Oct 2025 15:42:51 -0700 Subject: [PATCH 18/59] refactoring and fixing tests --- server/handler.go | 57 ++++++++++++++++++++++++++---------------- sql/expression/sort.go | 50 ------------------------------------ sql/row_frame.go | 2 +- sql/type.go | 7 +----- sql/types/bit.go | 13 ++++++++++ sql/types/datetime.go | 26 +++++++++++++++++++ sql/types/number.go | 31 +++++++++++------------ sql/types/strings.go | 20 +++++++++++++++ sql/types/time.go | 11 ++++++++ sql/types/year.go | 10 ++++++++ sql/values/encoding.go | 7 +++--- 11 files changed, 136 insertions(+), 98 deletions(-) diff --git a/server/handler.go b/server/handler.go index 1b71394a17..4337b01411 100644 --- a/server/handler.go +++ b/server/handler.go @@ -496,7 +496,7 @@ func (h *Handler) doQuery( } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, schema, ri2, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -770,14 +770,13 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } -func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter2, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End() eg, ctx := ctx.NewErrgroup() pan2err := func(err *error) { if recoveredPanic := recover(); recoveredPanic != nil { - stack := debug.Stack() - wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, debug.Stack()) *err = goerrors.Join(*err, wrappedErr) } } @@ -868,24 +867,9 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq if !ok { return nil } - resRow := make([]sqltypes.Value, len(row)) - for i, v := range row { - if v.Val != nil || v.WrappedVal == nil { - resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) - continue - } - dVal, err := v.WrappedVal.UnwrapAny(ctx) - if err != nil { - return err - } - switch dVal := dVal.(type) { - case []byte: - resRow[i] = sqltypes.MakeTrusted(v.Typ, dVal) - case string: - resRow[i] = sqltypes.MakeTrusted(v.Typ, []byte(dVal)) - default: - panic(fmt.Sprintf("unexpected type %T", dVal)) - } + resRow, err := RowValueToSQLValues(ctx, schema, row, buf) + if err != nil { + return err } ctx.GetLogger().Tracef("spooling result row %s", resRow) res.Rows = append(res.Rows, resRow) @@ -1187,6 +1171,35 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express return outVals, nil } +func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf *sql.ByteBuffer) ([]sqltypes.Value, error) { + if len(sch) == 0 { + return []sqltypes.Value{}, nil + } + var err error + outVals := make([]sqltypes.Value, len(sch)) + for i, col := range sch { + // TODO: remove this check once all Types implement this + valType, ok := col.Type.(sql.Type2) + if !ok { + outVals[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + continue + } + if buf == nil { + outVals[i], err = valType.ToSQLValue(ctx, row[i], nil) + if err != nil { + return nil, err + } + continue + } + outVals[i], err = valType.ToSQLValue(ctx, row[i], buf.Get()) + if err != nil { + return nil, err + } + buf.Grow(outVals[i].Len()) + } + return outVals, nil +} + func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field { charSetResults := ctx.GetCharacterSetResults() fields := make([]*querypb.Field, len(s)) diff --git a/sql/expression/sort.go b/sql/expression/sort.go index 164ef4f279..ab283aad98 100644 --- a/sql/expression/sort.go +++ b/sql/expression/sort.go @@ -102,56 +102,6 @@ func (s *Sorter2) Swap(i, j int) { s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i] } -func (s *Sorter2) Less(i, j int) bool { - if s.LastError != nil { - return false - } - - a := s.Rows[i] - b := s.Rows[j] - for _, sf := range s.SortFields { - typ := sf.Column2.Type2() - av, err := sf.Column2.Eval2(s.Ctx, a) - if err != nil { - s.LastError = sql.ErrUnableSort.Wrap(err) - return false - } - - bv, err := sf.Column2.Eval2(s.Ctx, b) - if err != nil { - s.LastError = sql.ErrUnableSort.Wrap(err) - return false - } - - if sf.Order == sql.Descending { - av, bv = bv, av - } - - if av.IsNull() && bv.IsNull() { - continue - } else if av.IsNull() { - return sf.NullOrdering == sql.NullsFirst - } else if bv.IsNull() { - return sf.NullOrdering != sql.NullsFirst - } - - cmp, err := typ.Compare2(av, bv) - if err != nil { - s.LastError = err - return false - } - - switch cmp { - case -1: - return true - case 1: - return false - } - } - - return false -} - // TopRowsHeap implements heap.Interface based on Sorter. It inverts the Less() // function so that it can be used to implement TopN. heap.Push() rows into it, // and if Len() > MAX; heap.Pop() the current min row. Then, at the end of diff --git a/sql/row_frame.go b/sql/row_frame.go index 1c659b59f4..067c24c400 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -30,7 +30,7 @@ type ValueBytes []byte // Value is a logical index into a ValueRow. For efficiency reasons, use sparingly. type Value struct { Val ValueBytes - WrappedVal AnyWrapper + WrappedVal BytesWrapper Typ querypb.Type // TODO: consider sqltypes.Type instead } diff --git a/sql/type.go b/sql/type.go index 6d9f9adb01..da49d1b1f2 100644 --- a/sql/type.go +++ b/sql/type.go @@ -294,12 +294,7 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type - // Compare2 returns an integer comparing two Values. - Compare2(Value, Value) (int, error) - // Convert2 converts a value of a compatible type. - Convert2(Value) (Value, error) - // Zero2 returns the zero Value for this type. - Zero2() Value + ToSQLValue(*Context, Value, []byte) (sqltypes.Value, error) } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/bit.go b/sql/types/bit.go index 7f9ef77d95..68227c4f8e 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -18,7 +18,9 @@ import ( "context" "encoding/binary" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "reflect" + "strconv" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -211,6 +213,17 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } +// ToSQLValue implements Type2 interface. +func (t BitType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + // Assume this is uint64 + x := values.ReadUint64(v.Val) + dest = strconv.AppendUint(dest, x, 10) + return sqltypes.MakeTrusted(sqltypes.Bit, dest), nil +} + // String implements Type interface. func (t BitType_) String() string { return fmt.Sprintf("bit(%v)", t.numOfBits) diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 387956f0fb..38eea46e3a 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -17,6 +17,7 @@ package types import ( "context" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "math" "reflect" "time" @@ -474,6 +475,31 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(typ, valBytes), nil } +func (t datetimeType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + switch t.baseType { + case sqltypes.Date: + // TODO: move this to values + x := values.ReadUint32(v.Val) + y := x >> 16 + m := (x & (255 << 8)) >> 8 + d := x & 255 + t := time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC) + dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout) + + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadInt64(v.Val) + t := time.UnixMicro(x).UTC() + dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout) + + default: + return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime") + } + return sqltypes.MakeTrusted(t.baseType, dest), nil +} + func (t datetimeType) String() string { switch t.baseType { case sqltypes.Date: diff --git a/sql/types/number.go b/sql/types/number.go index 824ab2b33c..14fbe2a4be 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -867,55 +867,54 @@ func (t NumberTypeImpl_) Zero2() sql.Value { } } -// SQL2 implements Type2 interface. -func (t NumberTypeImpl_) SQL2(v sql.Value) (sqltypes.Value, error) { +// ToSQLValue implements Type2 interface. +func (t NumberTypeImpl_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } - var val []byte switch t.baseType { case sqltypes.Int8: x := values.ReadInt8(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int16: x := values.ReadInt16(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int24: x := values.ReadInt24(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int32: x := values.ReadInt32(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int64: x := values.ReadInt64(v.Val) - val = []byte(strconv.FormatInt(x, 10)) + dest = strconv.AppendInt(dest, x, 10) case sqltypes.Uint8: x := values.ReadUint8(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint16: x := values.ReadUint16(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint24: x := values.ReadUint24(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint32: x := values.ReadUint32(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint64: x := values.ReadUint64(v.Val) - val = []byte(strconv.FormatUint(x, 10)) + dest = strconv.AppendUint(dest, x, 10) case sqltypes.Float32: x := values.ReadFloat32(v.Val) - val = []byte(strconv.FormatFloat(float64(x), 'f', -1, 32)) + dest = strconv.AppendFloat(dest, float64(x), 'f', -1, 32) case sqltypes.Float64: x := values.ReadFloat64(v.Val) - val = []byte(strconv.FormatFloat(x, 'f', -1, 64)) + dest = strconv.AppendFloat(dest, x, 'f', -1, 64) default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } - return sqltypes.MakeTrusted(t.baseType, val), nil + return sqltypes.MakeTrusted(t.baseType, dest), nil } // String implements Type interface. diff --git a/sql/types/strings.go b/sql/types/strings.go index e44e3d5a54..db66946aa7 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -790,6 +790,26 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. return sqltypes.MakeTrusted(t.baseType, val), nil } +// ToSQLValue implements ValueType interface. +func (t StringType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + // TODO: collations + // TODO: deal with casting numbers? + // No need to use dest buffer as we have already allocated []byte + var err error + if v.Val == nil && v.WrappedVal != nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return sqltypes.Value{}, err + } + } + + return sqltypes.MakeTrusted(t.baseType, v.Val), nil +} + // String implements Type interface. func (t StringType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/time.go b/sql/types/time.go index b8ca8b005e..0adf802c90 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -16,6 +16,7 @@ package types import ( "context" + "github.com/dolthub/go-mysql-server/sql/values" "math" "reflect" "strconv" @@ -267,6 +268,16 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes return sqltypes.MakeTrusted(sqltypes.Time, val), nil } +func (t TimespanType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + x := values.ReadInt64(v.Val) + // TODO: write version of this that takes advantage of dest + v.Val = Timespan(x).Bytes() + return sqltypes.MakeTrusted(sqltypes.Time, v.Val), nil +} + // String implements Type interface. func (t TimespanType_) String() string { return "time(6)" diff --git a/sql/types/year.go b/sql/types/year.go index c1e1ddc5ff..2362408035 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -16,6 +16,7 @@ package types import ( "context" + "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" "time" @@ -171,6 +172,15 @@ func (t YearType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.V return sqltypes.MakeTrusted(sqltypes.Year, val), nil } +func (t YearType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + x := values.ReadUint8(v.Val) + dest = strconv.AppendInt(dest, int64(x), 10) + return sqltypes.MakeTrusted(sqltypes.Year, dest), nil +} + // String implements Type interface. func (t YearType_) String() string { return "year" diff --git a/sql/values/encoding.go b/sql/values/encoding.go index d00e630091..45e9445b18 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -17,6 +17,7 @@ package values import ( "bytes" "encoding/binary" + "fmt" "math" ) @@ -129,13 +130,13 @@ func ReadUint16(val []byte) uint16 { } func ReadInt24(val []byte) (i int32) { - expectSize(val, Int24Size) + expectSize(val, Int32Size) i = int32(binary.LittleEndian.Uint32([]byte{0, val[0], val[1], val[2]})) return } func ReadUint24(val []byte) (u uint32) { - expectSize(val, Int24Size) + expectSize(val, Int32Size) var tmp [4]byte // copy |val| to |tmp| tmp[3], tmp[2] = val[3], val[2] @@ -306,7 +307,7 @@ func WriteBytes(buf, val []byte, coll Collation) []byte { func expectSize(buf []byte, sz ByteSize) { if ByteSize(len(buf)) != sz { - panic("byte slice is not of expected size") + panic(fmt.Sprintf("byte slice is length %v expected %v", len(buf), sz)) } } From 1ee3afeb36824cc2c01659442122e8dfe033e63c Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 20 Oct 2025 22:44:43 +0000 Subject: [PATCH 19/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/bit.go | 2 +- sql/types/datetime.go | 2 +- sql/types/time.go | 2 +- sql/types/year.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/types/bit.go b/sql/types/bit.go index 68227c4f8e..7026c16ca7 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -18,7 +18,6 @@ import ( "context" "encoding/binary" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" @@ -28,6 +27,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 38eea46e3a..b7c866e037 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -17,7 +17,6 @@ package types import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "math" "reflect" "time" @@ -29,6 +28,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) const ZeroDateStr = "0000-00-00" diff --git a/sql/types/time.go b/sql/types/time.go index 0adf802c90..696d278f08 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -16,7 +16,6 @@ package types import ( "context" - "github.com/dolthub/go-mysql-server/sql/values" "math" "reflect" "strconv" @@ -29,6 +28,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) var ( diff --git a/sql/types/year.go b/sql/types/year.go index 2362408035..19c9ae5ef0 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -16,7 +16,6 @@ package types import ( "context" - "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" "time" @@ -27,6 +26,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) var ( From 20bd197fecd08a0063b2b2b4a85e33c4b01179eb Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 21 Oct 2025 11:18:36 -0700 Subject: [PATCH 20/59] added collations --- sql/types/datetime.go | 2 +- sql/types/enum.go | 28 ++++++++++++++++++++++++++++ sql/types/set.go | 30 ++++++++++++++++++++++++++++++ sql/types/strings.go | 15 +++++++++++++-- 4 files changed, 72 insertions(+), 3 deletions(-) diff --git a/sql/types/datetime.go b/sql/types/datetime.go index b7c866e037..ed8deff684 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -481,7 +481,7 @@ func (t datetimeType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sq } switch t.baseType { case sqltypes.Date: - // TODO: move this to values + // TODO: move this to values package x := values.ReadUint32(v.Val) y := x >> 16 m := (x & (255 << 8)) >> 8 diff --git a/sql/types/enum.go b/sql/types/enum.go index 3dcfa27147..1524690ff4 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -17,6 +17,7 @@ package types import ( "context" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" "strings" @@ -268,6 +269,33 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } +func (t EnumType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + idx := values.ReadUint16(v.Val) + value, _ := t.At(int(idx)) + + charset := ctx.GetCharacterSetResults() + if charset == sql.CharacterSet_Unspecified || charset == sql.CharacterSet_binary { + charset = t.collation.CharacterSet() + } + + // TODO: write append style encoder + res, ok := charset.Encoder().Encode(encodings.StringToBytes(value)) // TODO: use unsafe string to byte + if !ok { + // return snippet of the converted value + if len(value) > 50 { + value = value[:50] + } + value = strings.ToValidUTF8(value, string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(charset.Name(), utf8.ValidString(value), value) + } + + return sqltypes.MakeTrusted(sqltypes.Enum, res), nil +} + // String implements Type interface. func (t EnumType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/set.go b/sql/types/set.go index 98b96f1390..274b6401ad 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -17,6 +17,7 @@ package types import ( "context" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "math" "math/bits" "reflect" @@ -261,6 +262,35 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val return sqltypes.MakeTrusted(sqltypes.Set, val), nil } +func (t SetType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + bits := values.ReadUint64(v.Val) + value, err := t.BitsToString(bits) + if err != nil { + return sqltypes.Value{}, err + } + + resultCharset := ctx.GetCharacterSetResults() + if resultCharset == sql.CharacterSet_Unspecified || resultCharset == sql.CharacterSet_binary { + resultCharset = t.collation.CharacterSet() + } + + // TODO: write append style encoder + res, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value)) // TODO: use unsafe string to byte + if !ok { + if len(value) > 50 { + value = value[:50] + } + value = strings.ToValidUTF8(value, string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), value) + } + + return sqltypes.MakeTrusted(sqltypes.Set, res), nil +} + // String implements Type interface. func (t SetType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/strings.go b/sql/types/strings.go index db66946aa7..067feb1522 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -796,7 +796,6 @@ func (t StringType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt return sqltypes.NULL, nil } - // TODO: collations // TODO: deal with casting numbers? // No need to use dest buffer as we have already allocated []byte var err error @@ -806,8 +805,20 @@ func (t StringType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt return sqltypes.Value{}, err } } + charset := ctx.GetCharacterSetResults() + if charset == sql.CharacterSet_Unspecified || charset == sql.CharacterSet_binary { + charset = t.collation.CharacterSet() + } + res, ok := charset.Encoder().Encode(v.Val) + if !ok { + if len(v.Val) > 50 { + v.Val = v.Val[:50] + } + snippetStr := strings2.ToValidUTF8(string(v.Val), string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(charset.Name(), utf8.ValidString(snippetStr), v.Val) + } - return sqltypes.MakeTrusted(t.baseType, v.Val), nil + return sqltypes.MakeTrusted(t.baseType, res), nil } // String implements Type interface. From 474cb6ce8021cde39c34ad4ad3b8dbf6ab8cbb00 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 21 Oct 2025 18:21:40 +0000 Subject: [PATCH 21/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/enum.go | 2 +- sql/types/set.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/types/enum.go b/sql/types/enum.go index 1524690ff4..fdbe81d7d5 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -17,7 +17,6 @@ package types import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" "strings" @@ -30,6 +29,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( diff --git a/sql/types/set.go b/sql/types/set.go index 274b6401ad..601e10b85a 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -17,7 +17,6 @@ package types import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "math" "math/bits" "reflect" @@ -31,6 +30,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( From 13e882b933fe57cc446e6fe09d30b513696db9a2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 21 Oct 2025 15:26:28 -0700 Subject: [PATCH 22/59] refactor --- server/handler.go | 10 +++++----- sql/plan/filter.go | 14 +++++++------- sql/plan/process.go | 12 ++++++------ sql/rowexec/transaction_iters.go | 12 ++++++------ sql/rows.go | 6 +++--- sql/table_iter.go | 20 ++++++++++---------- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/server/handler.go b/server/handler.go index 4337b01411..58a9da8640 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, schema, ri2, resultFields, buf, callback, more) + } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.CanSupport(sqlCtx) { + r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -770,8 +770,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } -func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter2, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { - defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End() +func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.ValueRowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForValueRowIter").End() eg, ctx := ctx.NewErrgroup() pan2err := func(err *error) { @@ -816,7 +816,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, schema case <-ctx.Done(): return context.Cause(ctx) default: - row, err := iter.Next2(ctx) + row, err := iter.NextValueRow(ctx) if err == io.EOF { return nil } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 4dfff0e359..dd8417091c 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -106,11 +106,11 @@ type FilterIter struct { childIter sql.RowIter cond2 sql.Expression2 - childIter2 sql.RowIter2 + childIter2 sql.ValueRowIter } var _ sql.RowIter = (*FilterIter)(nil) -var _ sql.RowIter2 = (*FilterIter)(nil) +var _ sql.ValueRowIter = (*FilterIter)(nil) // NewFilterIter creates a new FilterIter. func NewFilterIter( @@ -139,9 +139,9 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } -func (i *FilterIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { +func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { for { - row, err := i.childIter2.Next2(ctx) + row, err := i.childIter2.NextValueRow(ctx) if err != nil { return nil, err } @@ -155,13 +155,13 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { } } -func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { +func (i *FilterIter) CanSupport(ctx *sql.Context) bool { cond, ok := i.cond.(sql.Expression2) if !ok || !cond.IsExpr2() { return false } - childIter, ok := i.childIter.(sql.RowIter2) - if !ok || !childIter.IsRowIter2(ctx) { + childIter, ok := i.childIter.(sql.ValueRowIter) + if !ok || !childIter.CanSupport(ctx) { return false } i.cond2 = cond diff --git a/sql/plan/process.go b/sql/plan/process.go index e9974bbd72..2aad92531e 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -226,7 +226,7 @@ const ( type TrackedRowIter struct { node sql.Node iter sql.RowIter - iter2 sql.RowIter2 + iter2 sql.ValueRowIter onDone NotifyFunc onNext NotifyFunc numRows int64 @@ -318,8 +318,8 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } -func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { - row, err := i.iter2.Next2(ctx) +func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { + row, err := i.iter2.NextValueRow(ctx) if err != nil { return nil, err } @@ -330,9 +330,9 @@ func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { return row, nil } -func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { - iter, ok := i.iter.(sql.RowIter2) - if !ok || !iter.IsRowIter2(ctx) { +func (i *TrackedRowIter) CanSupport(ctx *sql.Context) bool { + iter, ok := i.iter.(sql.ValueRowIter) + if !ok || !iter.CanSupport(ctx) { return false } i.iter2 = iter diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index f2f466d281..5fe36df669 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -71,7 +71,7 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { // during the Close() operation type TransactionCommittingIter struct { childIter sql.RowIter - childIter2 sql.RowIter2 + childIter2 sql.ValueRowIter transactionDatabase string autoCommit bool implicitCommit bool @@ -100,13 +100,13 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } -func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.ValueRow, error) { - return t.childIter2.Next2(ctx) +func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { + return t.childIter2.NextValueRow(ctx) } -func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { - childIter, ok := t.childIter.(sql.RowIter2) - if !ok || !childIter.IsRowIter2(ctx) { +func (t *TransactionCommittingIter) CanSupport(ctx *sql.Context) bool { + childIter, ok := t.childIter.(sql.ValueRowIter) + if !ok || !childIter.CanSupport(ctx) { return false } t.childIter2 = childIter diff --git a/sql/rows.go b/sql/rows.go index 9f8ef3d1c5..04fa261316 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -88,10 +88,10 @@ type RowIter interface { Closer } -type RowIter2 interface { +type ValueRowIter interface { RowIter - Next2(ctx *Context) (ValueRow, error) - IsRowIter2(ctx *Context) bool + NextValueRow(ctx *Context) (ValueRow, error) + CanSupport(ctx *Context) bool } // RowIterToRows converts a row iterator to a slice of rows. diff --git a/sql/table_iter.go b/sql/table_iter.go index 8f4b12fa30..8ef451649f 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -26,7 +26,7 @@ type TableRowIter struct { partition Partition rows RowIter - rows2 RowIter2 + rows2 ValueRowIter } var _ RowIter = (*TableRowIter)(nil) @@ -79,7 +79,7 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } -func (i *TableRowIter) Next2(ctx *Context) (ValueRow, error) { +func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -105,26 +105,26 @@ func (i *TableRowIter) Next2(ctx *Context) (ValueRow, error) { if err != nil { return nil, err } - ri2, ok := rows.(RowIter2) - if !ok || !ri2.IsRowIter2(ctx) { - panic(fmt.Sprintf("%T does not implement RowIter2", rows)) + ri2, ok := rows.(ValueRowIter) + if !ok || !ri2.CanSupport(ctx) { + panic(fmt.Sprintf("%T does not implement ValueRowIter", rows)) } i.rows2 = ri2 } - row, err := i.rows2.Next2(ctx) + row, err := i.rows2.NextValueRow(ctx) if err != nil && err == io.EOF { if err = i.rows2.Close(ctx); err != nil { return nil, err } i.partition = nil i.rows2 = nil - row, err = i.Next2(ctx) + row, err = i.NextValueRow(ctx) } return row, err } -func (i *TableRowIter) IsRowIter2(ctx *Context) bool { +func (i *TableRowIter) CanSupport(ctx *Context) bool { if i.partition == nil { partition, err := i.partitions.Next(ctx) if err != nil { @@ -137,13 +137,13 @@ func (i *TableRowIter) IsRowIter2(ctx *Context) bool { if err != nil { return false } - ri2, ok := rows.(RowIter2) + ri2, ok := rows.(ValueRowIter) if !ok { return false } i.rows2 = ri2 } - return i.rows2.IsRowIter2(ctx) + return i.rows2.CanSupport(ctx) } func (i *TableRowIter) Close(ctx *Context) error { From 15f6b872dad56fa8264b735367675f8d078de05e Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 21 Oct 2025 23:50:38 -0700 Subject: [PATCH 23/59] fix types over wire --- sql/types/bit.go | 29 ++++++++++++++++++++++------- sql/types/datetime.go | 2 +- sql/types/decimal.go | 17 +++++++++++++++++ sql/types/year.go | 2 +- sql/values/encoding.go | 8 ++------ 5 files changed, 43 insertions(+), 15 deletions(-) diff --git a/sql/types/bit.go b/sql/types/bit.go index 7026c16ca7..546a46981c 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -18,16 +18,13 @@ import ( "context" "encoding/binary" "fmt" - "reflect" - "strconv" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" + "reflect" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/values" ) const ( @@ -218,9 +215,27 @@ func (t BitType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltyp if v.IsNull() { return sqltypes.NULL, nil } - // Assume this is uint64 - x := values.ReadUint64(v.Val) - dest = strconv.AppendUint(dest, x, 10) + + numBytes := t.numOfBits / 8 + if t.numOfBits%8 != 0 { + numBytes += 1 + } + + if uint8(len(v.Val)) < numBytes { + // already in little endian, so just pad with trailing 0s + for i := uint8(len(v.Val)); i <= t.numOfBits/8; i++ { + v.Val = append(v.Val, 0) + } + } else { + v.Val = v.Val[:numBytes] + } + + // TODO: for whatever reason, TestTypesOverWire only works when this is a deep copy? + dest = append(dest, v.Val...) + // want the results in big endian + for i, j := 0, len(dest)-1; i < j; i, j = i+1, j-1 { + dest[i], dest[j] = dest[j], dest[i] + } return sqltypes.MakeTrusted(sqltypes.Bit, dest), nil } diff --git a/sql/types/datetime.go b/sql/types/datetime.go index ed8deff684..54262009d3 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -487,7 +487,7 @@ func (t datetimeType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sq m := (x & (255 << 8)) >> 8 d := x & 255 t := time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC) - dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout) + dest = t.AppendFormat(dest, sql.DateLayout) case sqltypes.Datetime, sqltypes.Timestamp: x := values.ReadInt64(v.Val) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index ccfa6eb321..c876b2a82c 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,6 +17,7 @@ package types import ( "context" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "math/big" "reflect" "strings" @@ -329,6 +330,22 @@ func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } +func (t DecimalType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + // TODO: implement values.ReadDecimal + e := values.ReadInt32(v.Val[:values.Int32Size]) + s := values.ReadInt8(v.Val[values.Int32Size : values.Int32Size+values.Int8Size]) + b := big.NewInt(0).SetBytes(v.Val[values.Int32Size+values.Int8Size:]) + if s < 0 { + b = b.Neg(b) + } + d := decimal.NewFromBigInt(b, e) + val := AppendAndSliceString(dest, t.DecimalValueStringFixed(d)) + return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil +} + // String implements Type interface. func (t DecimalType_) String() string { return fmt.Sprintf("decimal(%v,%v)", t.precision, t.scale) diff --git a/sql/types/year.go b/sql/types/year.go index 19c9ae5ef0..91ac78e280 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -176,7 +176,7 @@ func (t YearType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlty if v.IsNull() { return sqltypes.NULL, nil } - x := values.ReadUint8(v.Val) + x := values.ReadUint16(v.Val) dest = strconv.AppendInt(dest, int64(x), 10) return sqltypes.MakeTrusted(sqltypes.Year, dest), nil } diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 45e9445b18..5db8eb3af4 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -131,17 +131,13 @@ func ReadUint16(val []byte) uint16 { func ReadInt24(val []byte) (i int32) { expectSize(val, Int32Size) - i = int32(binary.LittleEndian.Uint32([]byte{0, val[0], val[1], val[2]})) + i = int32(binary.LittleEndian.Uint32(val)) return } func ReadUint24(val []byte) (u uint32) { expectSize(val, Int32Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint32(tmp[:]) + u = binary.LittleEndian.Uint32(val) return } From 3fa454f2d8b5b23dee2e3b4a34eae4f2965a9af1 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 22 Oct 2025 07:01:01 +0000 Subject: [PATCH 24/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/bit.go | 3 ++- sql/types/decimal.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/types/bit.go b/sql/types/bit.go index 546a46981c..39d44dc391 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -18,11 +18,12 @@ import ( "context" "encoding/binary" "fmt" + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" - "reflect" "github.com/dolthub/go-mysql-server/sql" ) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index c876b2a82c..8164916a4d 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,7 +17,6 @@ package types import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "math/big" "reflect" "strings" @@ -28,6 +27,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( From f3cf2a3243fa2d0fefc6768dac8b82b0d6feff28 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 22 Oct 2025 13:40:19 -0700 Subject: [PATCH 25/59] fix nulls --- server/handler.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/handler.go b/server/handler.go index 58a9da8640..51dcb7d1bf 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1181,6 +1181,10 @@ func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf // TODO: remove this check once all Types implement this valType, ok := col.Type.(sql.Type2) if !ok { + if row[i].IsNull() { + outVals[i] = sqltypes.NULL + continue + } outVals[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) continue } From dbdeb2171a0205b152637885758444dd21ee2354 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 10:45:19 -0700 Subject: [PATCH 26/59] refactor and comments --- sql/core.go | 13 ++++++------ sql/expression/comparison.go | 36 +++++++++++++++----------------- sql/expression/get_field.go | 13 +++++------- sql/expression/literal.go | 16 +++++--------- sql/expression/namedliteral.go | 2 +- sql/expression/unresolved.go | 16 +++++++------- sql/plan/filter.go | 8 +++---- sql/plan/indexed_table_access.go | 4 ++-- sql/sort_field.go | 6 +++--- 9 files changed, 50 insertions(+), 64 deletions(-) diff --git a/sql/core.go b/sql/core.go index 37c32f1672..77b470765a 100644 --- a/sql/core.go +++ b/sql/core.go @@ -460,14 +460,13 @@ func DebugString(nodeOrExpression interface{}) string { panic(fmt.Sprintf("Expected sql.DebugString or fmt.Stringer for %T", nodeOrExpression)) } -// Expression2 is an experimental future interface alternative to Expression to provide faster access. -type Expression2 interface { +// ValueExpression is an experimental future interface alternative to Expression to provide faster access. +type ValueExpression interface { Expression - // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row ValueRow) (Value, error) - // Type2 returns the expression type. - Type2() Type2 - IsExpr2() bool + // EvalValue evaluates the given row frame and returns a result. + EvalValue(ctx *Context, row ValueRow) (Value, error) + // CanSupport indicates whether this expression and all it's children support ValueExpression. + CanSupport() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 0e44926350..1cf4083392 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -496,7 +496,7 @@ type GreaterThan struct { } var _ sql.Expression = (*GreaterThan)(nil) -var _ sql.Expression2 = (*GreaterThan)(nil) +var _ sql.ValueExpression = (*GreaterThan)(nil) var _ sql.CollationCoercible = (*GreaterThan)(nil) // NewGreaterThan creates a new GreaterThan expression. @@ -523,21 +523,22 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { - l, ok := gt.Left().(sql.Expression2) +// EvalValue implements the sql.ValueExpression interface. +func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + l, ok := gt.Left().(sql.ValueExpression) if !ok { - panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) + panic(fmt.Sprintf("%T does not implement sql.ValueExpression", gt.Left())) } - r, ok := gt.Right().(sql.Expression2) + r, ok := gt.Right().(sql.ValueExpression) if !ok { - panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right())) + panic(fmt.Sprintf("%T does not implement sql.ValueExpression", gt.Right())) } - lv, err := l.Eval2(ctx, row) + lv, err := l.EvalValue(ctx, row) if err != nil { return sql.Value{}, err } - rv, err := r.Eval2(ctx, row) + rv, err := r.EvalValue(ctx, row) if err != nil { return sql.Value{}, err } @@ -563,23 +564,20 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, err return ret, nil } -func (gt *GreaterThan) Type2() sql.Type2 { - return nil -} - -func (gt *GreaterThan) IsExpr2() bool { - lExpr, isExpr2 := gt.Left().(sql.Expression2) - if !isExpr2 { +// CanSupport implements the ValueExpression interface. +func (gt *GreaterThan) CanSupport() bool { + l, ok := gt.comparison.LeftChild.(sql.ValueExpression) + if !ok { return false } - if !lExpr.IsExpr2() { + if !l.CanSupport() { return false } - rExpr, isExpr2 := gt.Right().(sql.Expression2) - if !isExpr2 { + r, ok := gt.comparison.RightChild.(sql.ValueExpression) + if !ok { return false } - if !rExpr.IsExpr2() { + if !r.CanSupport() { return false } return true diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index e660c4fcb8..3571d97ba5 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -47,7 +47,7 @@ type GetField struct { } var _ sql.Expression = (*GetField)(nil) -var _ sql.Expression2 = (*GetField)(nil) +var _ sql.ValueExpression = (*GetField)(nil) var _ sql.CollationCoercible = (*GetField)(nil) var _ sql.IdExpression = (*GetField)(nil) @@ -133,11 +133,6 @@ func (p *GetField) Type() sql.Type { return p.fieldType } -// Type2 returns the type of the field, if this field has a sql.Type2. -func (p *GetField) Type2() sql.Type2 { - return p.fieldType2 -} - // ErrIndexOutOfBounds is returned when the field index is out of the bounds. var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns") @@ -149,14 +144,16 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { +// EvalValue implements the ValueExpression interface. +func (p *GetField) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= len(row) { return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, len(row)) } return row[p.fieldIndex], nil } -func (p *GetField) IsExpr2() bool { +// CanSupport implements the ValueExpression interface. +func (p *GetField) CanSupport() bool { return true } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 4d3f9334c7..0887697aeb 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -34,7 +34,7 @@ type Literal struct { } var _ sql.Expression = &Literal{} -var _ sql.Expression2 = &Literal{} +var _ sql.ValueExpression = &Literal{} var _ sql.CollationCoercible = &Literal{} var _ sqlparser.Injectable = &Literal{} @@ -136,22 +136,16 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { +// EvalValue implements the sql.ValueExpression interface. +func (lit *Literal) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { return lit.val2, nil } -func (lit *Literal) IsExpr2() bool { +// CanSupport implements the ValueExpression interface. +func (lit *Literal) CanSupport() bool { return true } -func (lit *Literal) Type2() sql.Type2 { - t2, ok := lit.Typ.(sql.Type2) - if !ok { - panic(fmt.Errorf("expected Type2, but was %T", lit.Typ)) - } - return t2 -} - // Value returns the literal value. func (lit *Literal) Value() interface{} { return lit.Val diff --git a/sql/expression/namedliteral.go b/sql/expression/namedliteral.go index ebf8d80ded..ce5550dd54 100644 --- a/sql/expression/namedliteral.go +++ b/sql/expression/namedliteral.go @@ -25,7 +25,7 @@ type NamedLiteral struct { } var _ sql.Expression = NamedLiteral{} -var _ sql.Expression2 = NamedLiteral{} +var _ sql.ValueExpression = NamedLiteral{} var _ sql.CollationCoercible = NamedLiteral{} // NewNamedLiteral returns a new NamedLiteral. diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 40e4799d75..b300f4c1a8 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -32,7 +32,7 @@ type UnresolvedColumn struct { } var _ sql.Expression = (*UnresolvedColumn)(nil) -var _ sql.Expression2 = (*UnresolvedColumn)(nil) +var _ sql.ValueExpression = (*UnresolvedColumn)(nil) var _ sql.CollationCoercible = (*UnresolvedColumn)(nil) // NewUnresolvedColumn creates a new UnresolvedColumn expression. @@ -71,16 +71,14 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { - panic("unresolved column is a placeholder node, but Eval2 was called") +// EvalValue implements the sql.ValueExpression interface. +func (uc *UnresolvedColumn) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + panic("unresolved column is a placeholder node, but EvalValue was called") } -func (uc *UnresolvedColumn) Type2() sql.Type2 { - panic("unresolved column is a placeholder node, but Type2 was called") -} - -func (uc *UnresolvedColumn) IsExpr2() bool { - panic("unresolved column is a placeholder node, but IsExpr2 was called") +// CanSupport implements the ValueExpression interface. +func (uc *UnresolvedColumn) CanSupport() bool { + panic("unresolved column is a placeholder node, but CanSupport was called") } // Name implements the Nameable interface. diff --git a/sql/plan/filter.go b/sql/plan/filter.go index dd8417091c..f8c1ccafe9 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -105,7 +105,7 @@ type FilterIter struct { cond sql.Expression childIter sql.RowIter - cond2 sql.Expression2 + cond2 sql.ValueExpression childIter2 sql.ValueRowIter } @@ -145,7 +145,7 @@ func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { if err != nil { return nil, err } - res, err := i.cond2.Eval2(ctx, row) + res, err := i.cond2.EvalValue(ctx, row) if err != nil { return nil, err } @@ -156,8 +156,8 @@ func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { } func (i *FilterIter) CanSupport(ctx *sql.Context) bool { - cond, ok := i.cond.(sql.Expression2) - if !ok || !cond.IsExpr2() { + cond, ok := i.cond.(sql.ValueExpression) + if !ok || !cond.CanSupport() { return false } childIter, ok := i.childIter.(sql.ValueRowIter) diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index ee768c8a1a..8afbdf7dd9 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -502,7 +502,7 @@ type lookupBuilderKey []interface{} type LookupBuilder struct { index sql.Index keyExprs []sql.Expression - keyExprs2 []sql.Expression2 + keyExprs2 []sql.ValueExpression // When building the lookup, we will use an MySQLIndexBuilder. If the // extracted lookup value is NULL, but we have a non-NULL safe // comparison, then the lookup should return no values. But if the @@ -642,7 +642,7 @@ func (lb *LookupBuilder) GetKey2(ctx *sql.Context, row sql.ValueRow) (lookupBuil } for i := range lb.keyExprs { var err error - lb.key[i], err = lb.keyExprs2[i].Eval2(ctx, row) + lb.key[i], err = lb.keyExprs2[i].EvalValue(ctx, row) if err != nil { return nil, err } diff --git a/sql/sort_field.go b/sql/sort_field.go index 02b844a07f..950cf7074a 100644 --- a/sql/sort_field.go +++ b/sql/sort_field.go @@ -24,8 +24,8 @@ import ( type SortField struct { // Column to order by. Column Expression - // Column Expression2 to order by. This is always the same value as Column, but avoids a type cast - Column2 Expression2 + // Column ValueExpression to order by. This is always the same value as Column, but avoids a type cast + Column2 ValueExpression // Order type. Order SortOrder // NullOrdering defining how nulls will be ordered. @@ -50,7 +50,7 @@ func (sf SortFields) FromExpressions(exprs ...Expression) SortFields { } for i, expr := range exprs { - expr2, _ := expr.(Expression2) + expr2, _ := expr.(ValueExpression) fields[i] = SortField{ Column: expr, Column2: expr2, From 8f27277e43e59c4cf10b9bfe9681a78946fa563a Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 12:04:13 -0700 Subject: [PATCH 27/59] clean up --- server/handler.go | 5 ++- sql/expression/comparison.go | 64 +++++++++++++++------------- sql/expression/literal.go | 2 +- sql/plan/filter.go | 10 ++--- sql/plan/process.go | 9 +--- sql/row_frame.go | 12 +++--- sql/rowexec/transaction_iters.go | 9 +--- sql/table_iter.go | 28 +++++-------- sql/types/number.go | 72 ++++++++++++++++---------------- 9 files changed, 98 insertions(+), 113 deletions(-) diff --git a/server/handler.go b/server/handler.go index 51dcb7d1bf..63da5425bc 100644 --- a/server/handler.go +++ b/server/handler.go @@ -702,7 +702,10 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s defer wg.Done() for { if r == nil { - r = &sqltypes.Result{Fields: resultFields} + r = &sqltypes.Result{ + Rows: make([][]sqltypes.Value, 0, rowsBatch), + Fields: resultFields, + } } if r.RowsAffected == rowsBatch { if err := resetCallback(r, more); err != nil { diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 1cf4083392..06ea5472a4 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,12 +16,11 @@ package expression import ( "fmt" - - querypb "github.com/dolthub/vitess/go/vt/proto/query" - errors "gopkg.in/src-d/go-errors.v1" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" ) var ErrInvalidRegexp = errors.NewKind("Invalid regular expression: %s") @@ -525,43 +524,48 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) // EvalValue implements the sql.ValueExpression interface. func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { - l, ok := gt.Left().(sql.ValueExpression) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.ValueExpression", gt.Left())) - } - r, ok := gt.Right().(sql.ValueExpression) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.ValueExpression", gt.Right())) - } - - lv, err := l.EvalValue(ctx, row) + lv, err := gt.comparison.LeftChild.(sql.ValueExpression).EvalValue(ctx, row) if err != nil { return sql.Value{}, err } - rv, err := r.EvalValue(ctx, row) + rv, err := gt.comparison.RightChild.(sql.ValueExpression).EvalValue(ctx, row) if err != nil { return sql.Value{}, err } - // TODO: just assume they are int64 - l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) - if err != nil { - return sql.Value{}, err - } - r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) - if err != nil { - return sql.Value{}, err - } - var rb byte - if l64 > r64 { - rb = 1 + // TODO: move this logic into comparison + var cmp byte + if sqltypes.IsUnsigned(lv.Typ) && sqltypes.IsUnsigned(rv.Typ) { + l, cErr := types.ConvertValueToUint64(lv) + if cErr != nil { + return sql.Value{}, cErr + } + r, cErr := types.ConvertValueToUint64(rv) + if cErr != nil { + return sql.Value{}, cErr + } + if l > r { + cmp = 1 + } + } else { + l, cErr := types.ConvertValueToInt64(lv) + if cErr != nil { + return sql.Value{}, cErr + } + r, cErr := types.ConvertValueToInt64(rv) + if cErr != nil { + return sql.Value{}, cErr + } + if l > r { + cmp = 1 + } } - ret := sql.Value{ - Val: []byte{rb}, - Typ: querypb.Type_INT8, + res := sql.Value{ + Val: []byte{cmp}, + Typ: sqltypes.Int8, } - return ret, nil + return res, nil } // CanSupport implements the ValueExpression interface. diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 0887697aeb..0a1122a429 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -143,7 +143,7 @@ func (lit *Literal) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, er // CanSupport implements the ValueExpression interface. func (lit *Literal) CanSupport() bool { - return true + return types.IsInteger(lit.Typ) } // Value returns the literal value. diff --git a/sql/plan/filter.go b/sql/plan/filter.go index f8c1ccafe9..bd70107779 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -104,9 +104,6 @@ func (f *Filter) Expressions() []sql.Expression { type FilterIter struct { cond sql.Expression childIter sql.RowIter - - cond2 sql.ValueExpression - childIter2 sql.ValueRowIter } var _ sql.RowIter = (*FilterIter)(nil) @@ -141,11 +138,11 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { for { - row, err := i.childIter2.NextValueRow(ctx) + row, err := i.childIter.(sql.ValueRowIter).NextValueRow(ctx) if err != nil { return nil, err } - res, err := i.cond2.EvalValue(ctx, row) + res, err := i.cond.(sql.ValueExpression).EvalValue(ctx, row) if err != nil { return nil, err } @@ -155,6 +152,7 @@ func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { } } +// CanSupport implements the sql.ValueRowIter interface. func (i *FilterIter) CanSupport(ctx *sql.Context) bool { cond, ok := i.cond.(sql.ValueExpression) if !ok || !cond.CanSupport() { @@ -164,8 +162,6 @@ func (i *FilterIter) CanSupport(ctx *sql.Context) bool { if !ok || !childIter.CanSupport(ctx) { return false } - i.cond2 = cond - i.childIter2 = childIter return true } diff --git a/sql/plan/process.go b/sql/plan/process.go index 2aad92531e..2b23d4aa6d 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -226,7 +226,6 @@ const ( type TrackedRowIter struct { node sql.Node iter sql.RowIter - iter2 sql.ValueRowIter onDone NotifyFunc onNext NotifyFunc numRows int64 @@ -319,7 +318,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { } func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { - row, err := i.iter2.NextValueRow(ctx) + row, err := i.iter.(sql.ValueRowIter).NextValueRow(ctx) if err != nil { return nil, err } @@ -332,11 +331,7 @@ func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { func (i *TrackedRowIter) CanSupport(ctx *sql.Context) bool { iter, ok := i.iter.(sql.ValueRowIter) - if !ok || !iter.CanSupport(ctx) { - return false - } - i.iter2 = iter - return true + return ok && iter.CanSupport(ctx) } func (i *TrackedRowIter) Close(ctx *sql.Context) error { diff --git a/sql/row_frame.go b/sql/row_frame.go index 067c24c400..a648ef7edb 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -17,7 +17,7 @@ package sql import ( "sync" - querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/proto/query" ) const ( @@ -31,7 +31,7 @@ type ValueBytes []byte type Value struct { Val ValueBytes WrappedVal BytesWrapper - Typ querypb.Type // TODO: consider sqltypes.Type instead + Typ query.Type } // ValueRow is a slice of values @@ -39,11 +39,11 @@ type ValueRow []Value // IsNull returns whether this value represents NULL func (v Value) IsNull() bool { - return (v.Val == nil && v.WrappedVal == nil) || v.Typ == querypb.Type_NULL_TYPE + return (v.Val == nil && v.WrappedVal == nil) || v.Typ == query.Type_NULL_TYPE } type RowFrame struct { - Types []querypb.Type + Types []query.Type // Values are the values this row. Values []ValueBytes @@ -128,7 +128,7 @@ func (f *RowFrame) Append(vals ...Value) { } // AppendMany appends the types and values given, as two parallel arrays, into this frame. -func (f *RowFrame) AppendMany(types []querypb.Type, vals []ValueBytes) { +func (f *RowFrame) AppendMany(types []query.Type, vals []ValueBytes) { // TODO: one big copy here would be better probably, need to benchmark for i := range vals { f.appendTypeAndVal(types[i], vals[i]) @@ -147,7 +147,7 @@ func (f *RowFrame) append(v Value) { f.Values = append(f.Values, v.Val) } -func (f *RowFrame) appendTypeAndVal(typ querypb.Type, val ValueBytes) { +func (f *RowFrame) appendTypeAndVal(typ query.Type, val ValueBytes) { v := f.bufferForBytes(val) copy(v, val) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index 5fe36df669..fda46f31a6 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -71,7 +71,6 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { // during the Close() operation type TransactionCommittingIter struct { childIter sql.RowIter - childIter2 sql.ValueRowIter transactionDatabase string autoCommit bool implicitCommit bool @@ -101,16 +100,12 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { } func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { - return t.childIter2.NextValueRow(ctx) + return t.childIter.(sql.ValueRowIter).NextValueRow(ctx) } func (t *TransactionCommittingIter) CanSupport(ctx *sql.Context) bool { childIter, ok := t.childIter.(sql.ValueRowIter) - if !ok || !childIter.CanSupport(ctx) { - return false - } - t.childIter2 = childIter - return true + return ok && childIter.CanSupport(ctx) } func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { diff --git a/sql/table_iter.go b/sql/table_iter.go index 8ef451649f..56de252d0c 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -15,7 +15,6 @@ package sql import ( - "fmt" "io" ) @@ -25,8 +24,6 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter - - rows2 ValueRowIter } var _ RowIter = (*TableRowIter)(nil) @@ -79,6 +76,7 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } +// NextValueRow implements the sql.ValueRowIter interface func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { select { case <-ctx.Done(): @@ -93,32 +91,26 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { return nil, e } } - return nil, err } - i.partition = partition } - if i.rows2 == nil { + if i.rows == nil { rows, err := i.table.PartitionRows(ctx, i.partition) if err != nil { return nil, err } - ri2, ok := rows.(ValueRowIter) - if !ok || !ri2.CanSupport(ctx) { - panic(fmt.Sprintf("%T does not implement ValueRowIter", rows)) - } - i.rows2 = ri2 + i.rows = rows } - row, err := i.rows2.NextValueRow(ctx) + row, err := i.rows.(ValueRowIter).NextValueRow(ctx) if err != nil && err == io.EOF { - if err = i.rows2.Close(ctx); err != nil { + if err = i.rows.Close(ctx); err != nil { return nil, err } i.partition = nil - i.rows2 = nil + i.rows = nil row, err = i.NextValueRow(ctx) } return row, err @@ -132,18 +124,18 @@ func (i *TableRowIter) CanSupport(ctx *Context) bool { } i.partition = partition } - if i.rows2 == nil { + if i.rows == nil { rows, err := i.table.PartitionRows(ctx, i.partition) if err != nil { return false } - ri2, ok := rows.(ValueRowIter) + valRowIter, ok := rows.(ValueRowIter) if !ok { return false } - i.rows2 = ri2 + i.rows = valRowIter } - return i.rows2.CanSupport(ctx) + return i.rows.(ValueRowIter).CanSupport(ctx) } func (i *TableRowIter) Close(ctx *Context) error { diff --git a/sql/types/number.go b/sql/types/number.go index 14fbe2a4be..7e34a30ff4 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -731,11 +731,11 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, err := convertValueToUint64(t, a) + ca, err := ConvertValueToUint64(a) if err != nil { return 0, err } - cb, err := convertValueToUint64(t, b) + cb, err := ConvertValueToUint64(b) if err != nil { return 0, err } @@ -765,11 +765,11 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } return +1, nil default: - ca, err := ConvertValueToInt64(t, a) + ca, err := ConvertValueToInt64(a) if err != nil { return 0, err } - cb, err := ConvertValueToInt64(t, b) + cb, err := ConvertValueToInt64(b) if err != nil { return 0, err } @@ -1151,91 +1151,91 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func ConvertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { +func ConvertValueToInt64(v sql.Value) (int64, error) { switch v.Typ { case query.Type_INT8: return int64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: + case sqltypes.Int16: return int64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: + case sqltypes.Int24: return int64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: + case sqltypes.Int32: return int64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: + case sqltypes.Int64: return values.ReadInt64(v.Val), nil - case query.Type_UINT8: + case sqltypes.Uint8: return int64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: + case sqltypes.Uint16: return int64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: + case sqltypes.Uint24: return int64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: + case sqltypes.Uint32: return int64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: + case sqltypes.Uint64: v := values.ReadUint64(v.Val) if v > math.MaxInt64 { return math.MaxInt64, nil } return int64(v), nil - case query.Type_FLOAT32: + case sqltypes.Float32: v := values.ReadFloat32(v.Val) if v > float32(math.MaxInt64) { return math.MaxInt64, nil - } else if v < float32(math.MinInt64) { + } + if v < float32(math.MinInt64) { return math.MinInt64, nil } return int64(math.Round(float64(v))), nil - case query.Type_FLOAT64: + case sqltypes.Float64: v := values.ReadFloat64(v.Val) if v > float64(math.MaxInt64) { return math.MaxInt64, nil - } else if v < float64(math.MinInt64) { + } + if v < float64(math.MinInt64) { return math.MinInt64, nil } return int64(math.Round(v)), nil - // TODO: add more conversions default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) + return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } } -func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { +func ConvertValueToUint64(v sql.Value) (uint64, error) { switch v.Typ { - case query.Type_INT8: + case sqltypes.Int8: return uint64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: + case sqltypes.Int16: return uint64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: + case sqltypes.Int24: return uint64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: + case sqltypes.Int32: return uint64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: + case sqltypes.Int64: return uint64(values.ReadInt64(v.Val)), nil - case query.Type_UINT8: + case sqltypes.Uint8: return uint64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: + case sqltypes.Uint16: return uint64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: + case sqltypes.Uint24: return uint64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: + case sqltypes.Uint32: return uint64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: + case sqltypes.Uint64: return values.ReadUint64(v.Val), nil - case query.Type_FLOAT32: + case sqltypes.Float32: v := values.ReadFloat32(v.Val) if v >= float32(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(float64(v))), nil - case query.Type_FLOAT64: + case sqltypes.Float64: v := values.ReadFloat64(v.Val) - if v >= float64(math.MaxUint64) { + if v > float64(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(v)), nil - // TODO: add more conversions default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) + return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } } From e88631e10ca94df4e9f6f3e5e242689bac8424f9 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 12:29:50 -0700 Subject: [PATCH 28/59] more refactoring --- server/handler.go | 4 ++-- sql/expression/get_field.go | 5 +---- sql/type.go | 2 +- sql/types/bit.go | 16 ++++++---------- sql/types/number.go | 4 ++-- 5 files changed, 12 insertions(+), 19 deletions(-) diff --git a/server/handler.go b/server/handler.go index 63da5425bc..bf1e054fee 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1131,7 +1131,7 @@ func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interf return typ.SQL(ctx, nil, val) } ret, err := typ.SQL(ctx, buf.Get(), val) - buf.Grow(ret.Len()) + buf.Grow(ret.Len()) // TODO: shouldn't we check capacity beforehand? return ret, err } @@ -1182,7 +1182,7 @@ func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf outVals := make([]sqltypes.Value, len(sch)) for i, col := range sch { // TODO: remove this check once all Types implement this - valType, ok := col.Type.(sql.Type2) + valType, ok := col.Type.(sql.ValueType) if !ok { if row[i].IsNull() { outVals[i] = sqltypes.NULL diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 3571d97ba5..740cad068b 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -25,8 +25,7 @@ import ( // GetField is an expression to get the field of a table. type GetField struct { - fieldType sql.Type - fieldType2 sql.Type2 + fieldType sql.Type // schemaFormatter is the schemaFormatter used to quote field names schemaFormatter sql.SchemaFormatter @@ -58,13 +57,11 @@ func NewGetField(index int, fieldType sql.Type, fieldName string, nullable bool) // NewGetFieldWithTable creates a GetField expression with table name. The table name may be an alias. func NewGetFieldWithTable(index, tableId int, fieldType sql.Type, db, table, fieldName string, nullable bool) *GetField { - fieldType2, _ := fieldType.(sql.Type2) return &GetField{ db: db, table: table, fieldIndex: index, fieldType: fieldType, - fieldType2: fieldType2, name: fieldName, nullable: nullable, exprId: sql.ColumnId(index), diff --git a/sql/type.go b/sql/type.go index da49d1b1f2..7f45664062 100644 --- a/sql/type.go +++ b/sql/type.go @@ -292,7 +292,7 @@ func IsDecimalType(t Type) bool { return ok } -type Type2 interface { +type ValueType interface { Type ToSQLValue(*Context, Value, []byte) (sqltypes.Value, error) } diff --git a/sql/types/bit.go b/sql/types/bit.go index 39d44dc391..f4c4345c9f 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -211,27 +211,23 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } -// ToSQLValue implements Type2 interface. +// ToSQLValue implements ValueType interface. func (t BitType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } + // Trim/Pad result to the appropriate length numBytes := t.numOfBits / 8 if t.numOfBits%8 != 0 { numBytes += 1 } - - if uint8(len(v.Val)) < numBytes { - // already in little endian, so just pad with trailing 0s - for i := uint8(len(v.Val)); i <= t.numOfBits/8; i++ { - v.Val = append(v.Val, 0) - } - } else { - v.Val = v.Val[:numBytes] + for i := uint8(len(v.Val)); i < numBytes; i++ { + v.Val = append(v.Val, 0) } + v.Val = v.Val[:numBytes] - // TODO: for whatever reason, TestTypesOverWire only works when this is a deep copy? + // TODO: for whatever reason TestTypesOverWire only works when this is a deep copy? dest = append(dest, v.Val...) // want the results in big endian for i, j := 0, len(dest)-1; i < j; i, j = i+1, j-1 { diff --git a/sql/types/number.go b/sql/types/number.go index 7e34a30ff4..762712994f 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -102,7 +102,7 @@ type NumberTypeImpl_ struct { } var _ sql.Type = NumberTypeImpl_{} -var _ sql.Type2 = NumberTypeImpl_{} +var _ sql.ValueType = NumberTypeImpl_{} var _ sql.CollationCoercible = NumberTypeImpl_{} var _ sql.NumberType = NumberTypeImpl_{} var _ sql.RoundingNumberType = NumberTypeImpl_{} @@ -867,7 +867,7 @@ func (t NumberTypeImpl_) Zero2() sql.Value { } } -// ToSQLValue implements Type2 interface. +// ToSQLValue implements ValueType interface. func (t NumberTypeImpl_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil From 1464a66fcfcb07a319a456cb94f35c30cbe3f27d Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 23 Oct 2025 19:32:27 +0000 Subject: [PATCH 29/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 06ea5472a4..b442c839e7 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,11 +16,12 @@ package expression import ( "fmt" + + "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/vitess/go/sqltypes" ) var ErrInvalidRegexp = errors.NewKind("Invalid regular expression: %s") From ad8119a74ea214e0ab249690b238e8b977ae75a0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 12:48:39 -0700 Subject: [PATCH 30/59] clean up --- sql/core.go | 2 +- sql/plan/filter.go | 1 + sql/plan/process.go | 2 ++ sql/rowexec/transaction_iters.go | 2 ++ sql/table_iter.go | 1 + 5 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core.go b/sql/core.go index 77b470765a..630ba7bf5b 100644 --- a/sql/core.go +++ b/sql/core.go @@ -465,7 +465,7 @@ type ValueExpression interface { Expression // EvalValue evaluates the given row frame and returns a result. EvalValue(ctx *Context, row ValueRow) (Value, error) - // CanSupport indicates whether this expression and all it's children support ValueExpression. + // CanSupport indicates whether this expression and all its children support ValueExpression. CanSupport() bool } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index bd70107779..b1252e8dc0 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -136,6 +136,7 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } +// NextValueRow implements the sql.ValueRowIter interface. func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { for { row, err := i.childIter.(sql.ValueRowIter).NextValueRow(ctx) diff --git a/sql/plan/process.go b/sql/plan/process.go index 2b23d4aa6d..7d676c6433 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -317,6 +317,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } +// NextValueRow implements the sql.ValueRowIter interface. func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { row, err := i.iter.(sql.ValueRowIter).NextValueRow(ctx) if err != nil { @@ -329,6 +330,7 @@ func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { return row, nil } +// CanSupport implements the sql.ValueRowIter interface. func (i *TrackedRowIter) CanSupport(ctx *sql.Context) bool { iter, ok := i.iter.(sql.ValueRowIter) return ok && iter.CanSupport(ctx) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index fda46f31a6..f0987a3ab7 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -99,10 +99,12 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } +// NextValueRow implements the sql.ValueRowIter interface. func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { return t.childIter.(sql.ValueRowIter).NextValueRow(ctx) } +// CanSupport implements the sql.ValueRowIter interface. func (t *TransactionCommittingIter) CanSupport(ctx *sql.Context) bool { childIter, ok := t.childIter.(sql.ValueRowIter) return ok && childIter.CanSupport(ctx) diff --git a/sql/table_iter.go b/sql/table_iter.go index 56de252d0c..8beb1daf2c 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -116,6 +116,7 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { return row, err } +// CanSupport implements the sql.ValueRowIter interface. func (i *TableRowIter) CanSupport(ctx *Context) bool { if i.partition == nil { partition, err := i.partitions.Next(ctx) From 47c8ad6dc5b6f443beb1bdbbb54b41ff6ef7ce98 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 13:11:16 -0700 Subject: [PATCH 31/59] hide under feature flag --- server/handler.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/handler.go b/server/handler.go index bf1e054fee..fdb2862bd0 100644 --- a/server/handler.go +++ b/server/handler.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net" + "os" "regexp" "runtime/debug" "runtime/trace" @@ -67,6 +68,8 @@ const ( MultiStmtModeOn MultiStmtMode = 1 ) +var enableRowValue = os.Getenv("DOLT_EXPERIMENTAL_VALUE_ROW") != "" + // Handler is a connection handler for a SQLe engine, implementing the Vitess mysql.Handler interface. type Handler struct { sel ServerEventListener @@ -495,7 +498,7 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.CanSupport(sqlCtx) { + } else if vr, ok := rowIter.(sql.ValueRowIter); enableRowValue && ok && vr.CanSupport(sqlCtx) { r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) From 5e3aeb58f0f73cd75a5cdff73f19b228a56d9d11 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 13:30:43 -0700 Subject: [PATCH 32/59] no feature flag --- server/handler.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/handler.go b/server/handler.go index fdb2862bd0..bf1e054fee 100644 --- a/server/handler.go +++ b/server/handler.go @@ -21,7 +21,6 @@ import ( "fmt" "io" "net" - "os" "regexp" "runtime/debug" "runtime/trace" @@ -68,8 +67,6 @@ const ( MultiStmtModeOn MultiStmtMode = 1 ) -var enableRowValue = os.Getenv("DOLT_EXPERIMENTAL_VALUE_ROW") != "" - // Handler is a connection handler for a SQLe engine, implementing the Vitess mysql.Handler interface. type Handler struct { sel ServerEventListener @@ -498,7 +495,7 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if vr, ok := rowIter.(sql.ValueRowIter); enableRowValue && ok && vr.CanSupport(sqlCtx) { + } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.CanSupport(sqlCtx) { r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) From 2c201b1942e26e91d596769cad21bf78cd454c4c Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 23 Oct 2025 13:50:57 -0700 Subject: [PATCH 33/59] reset memory --- server/handler.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/server/handler.go b/server/handler.go index bf1e054fee..cc3d5f4509 100644 --- a/server/handler.go +++ b/server/handler.go @@ -807,6 +807,17 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema wg := sync.WaitGroup{} wg.Add(2) + // Wrap the callback to include a BytesBuffer.Reset() for non-cursor requests, to + // clean out rows that have already been spooled. + resetCallback := func(r *sqltypes.Result, more bool) error { + // A server-side cursor allows the caller to fetch results cached on the server-side, + // so if a cursor exists, we can't release the buffer memory yet. + if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 { + defer buf.Reset() + } + return callback(r, more) + } + // TODO: send results instead of rows? // Read rows from iter and send them off var rowChan = make(chan sql.ValueRow, 512) @@ -849,7 +860,7 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema } } if res.RowsAffected == rowsBatch { - if err := callback(res, more); err != nil { + if err := resetCallback(res, more); err != nil { return err } res = nil From 86202f97cd12dd63f9f560fb6edee186bcc5440d Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 26 Oct 2025 18:00:40 -0700 Subject: [PATCH 34/59] feedback --- server/handler.go | 11 ++++------- sql/cache.go | 16 ++++++++-------- sql/core.go | 4 ++-- sql/expression/comparison.go | 8 ++++---- sql/expression/get_field.go | 4 ++-- sql/expression/literal.go | 4 ++-- sql/expression/sort.go | 8 ++++---- sql/expression/unresolved.go | 6 +++--- sql/memory.go | 4 ++-- sql/plan/filter.go | 6 +++--- sql/plan/indexed_table_access.go | 12 ++++++------ sql/plan/process.go | 4 ++-- sql/planbuilder/parse_old_test.go | 26 +++++++++++++------------- sql/rowexec/transaction_iters.go | 4 ++-- sql/rows.go | 10 +++++++--- sql/sort_field.go | 12 ++++++------ sql/table_iter.go | 4 ++-- sql/type.go | 14 +++++++++----- sql/types/bit.go | 2 +- sql/types/datetime.go | 2 +- sql/types/decimal.go | 2 +- sql/types/enum.go | 3 +-- sql/types/number.go | 2 +- sql/types/set.go | 2 +- sql/types/strings.go | 2 +- sql/types/time.go | 2 +- sql/types/year.go | 2 +- sql/{row_frame.go => value_row.go} | 8 ++++---- 28 files changed, 94 insertions(+), 90 deletions(-) rename sql/{row_frame.go => value_row.go} (92%) diff --git a/server/handler.go b/server/handler.go index cc3d5f4509..0cdcea2f57 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,7 +495,7 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.CanSupport(sqlCtx) { + } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.IsValueRowIter(sqlCtx) { r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) @@ -702,10 +702,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s defer wg.Done() for { if r == nil { - r = &sqltypes.Result{ - Rows: make([][]sqltypes.Value, 0, rowsBatch), - Fields: resultFields, - } + r = &sqltypes.Result{Fields: resultFields} } if r.RowsAffected == rowsBatch { if err := resetCallback(r, more); err != nil { @@ -1203,13 +1200,13 @@ func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf continue } if buf == nil { - outVals[i], err = valType.ToSQLValue(ctx, row[i], nil) + outVals[i], err = valType.SQLValue(ctx, row[i], nil) if err != nil { return nil, err } continue } - outVals[i], err = valType.ToSQLValue(ctx, row[i], buf.Get()) + outVals[i], err = valType.SQLValue(ctx, row[i], buf.Get()) if err != nil { return nil, err } diff --git a/sql/cache.go b/sql/cache.go index b2984eac24..ec772a0e9d 100644 --- a/sql/cache.go +++ b/sql/cache.go @@ -71,10 +71,10 @@ func (l *lruCache) Dispose() { } type rowsCache struct { - memory Freeable - reporter Reporter - rows []Row - rows2 []ValueRow + memory Freeable + reporter Reporter + rows []Row + valueRows []ValueRow } func newRowsCache(memory Freeable, r Reporter) *rowsCache { @@ -92,17 +92,17 @@ func (c *rowsCache) Add(row Row) error { func (c *rowsCache) Get() []Row { return c.rows } -func (c *rowsCache) Add2(row2 ValueRow) error { +func (c *rowsCache) AddValueRow(row ValueRow) error { if !releaseMemoryIfNeeded(c.reporter, c.memory.Free) { return ErrNoMemoryAvailable.New() } - c.rows2 = append(c.rows2, row2) + c.valueRows = append(c.valueRows, row) return nil } -func (c *rowsCache) Get2() []ValueRow { - return c.rows2 +func (c *rowsCache) GetValueRow() []ValueRow { + return c.valueRows } func (c *rowsCache) Dispose() { diff --git a/sql/core.go b/sql/core.go index 630ba7bf5b..44100f02c4 100644 --- a/sql/core.go +++ b/sql/core.go @@ -465,8 +465,8 @@ type ValueExpression interface { Expression // EvalValue evaluates the given row frame and returns a result. EvalValue(ctx *Context, row ValueRow) (Value, error) - // CanSupport indicates whether this expression and all its children support ValueExpression. - CanSupport() bool + // IsValueExpression indicates whether this expression and all its children support ValueExpression. + IsValueExpression() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index b442c839e7..6703f70e95 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -569,20 +569,20 @@ func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, return res, nil } -// CanSupport implements the ValueExpression interface. -func (gt *GreaterThan) CanSupport() bool { +// IsValueRowIter implements the ValueExpression interface. +func (gt *GreaterThan) IsValueExpression() bool { l, ok := gt.comparison.LeftChild.(sql.ValueExpression) if !ok { return false } - if !l.CanSupport() { + if !l.IsValueExpression() { return false } r, ok := gt.comparison.RightChild.(sql.ValueExpression) if !ok { return false } - if !r.CanSupport() { + if !r.IsValueExpression() { return false } return true diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 740cad068b..ec3d5115ac 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -149,8 +149,8 @@ func (p *GetField) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, err return row[p.fieldIndex], nil } -// CanSupport implements the ValueExpression interface. -func (p *GetField) CanSupport() bool { +// IsValueRowIter implements the ValueExpression interface. +func (p *GetField) IsValueExpression() bool { return true } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 0a1122a429..f48027dae3 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -141,8 +141,8 @@ func (lit *Literal) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, er return lit.val2, nil } -// CanSupport implements the ValueExpression interface. -func (lit *Literal) CanSupport() bool { +// IsValueRowIter implements the ValueExpression interface. +func (lit *Literal) IsValueExpression() bool { return types.IsInteger(lit.Typ) } diff --git a/sql/expression/sort.go b/sql/expression/sort.go index ab283aad98..6b75031677 100644 --- a/sql/expression/sort.go +++ b/sql/expression/sort.go @@ -86,19 +86,19 @@ func (s *Sorter) Less(i, j int) bool { return false } -// Sorter2 is a version of Sorter that operates on ValueRow -type Sorter2 struct { +// ValueRowSorter is a version of Sorter that operates on ValueRow +type ValueRowSorter struct { LastError error Ctx *sql.Context SortFields []sql.SortField Rows []sql.ValueRow } -func (s *Sorter2) Len() int { +func (s *ValueRowSorter) Len() int { return len(s.Rows) } -func (s *Sorter2) Swap(i, j int) { +func (s *ValueRowSorter) Swap(i, j int) { s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i] } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index b300f4c1a8..a18c7ccf73 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -76,9 +76,9 @@ func (uc *UnresolvedColumn) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.V panic("unresolved column is a placeholder node, but EvalValue was called") } -// CanSupport implements the ValueExpression interface. -func (uc *UnresolvedColumn) CanSupport() bool { - panic("unresolved column is a placeholder node, but CanSupport was called") +// IsValueRowIter implements the ValueExpression interface. +func (uc *UnresolvedColumn) IsValueExpression() bool { + panic("unresolved column is a placeholder node, but IsValueExpression was called") } // Name implements the Nameable interface. diff --git a/sql/memory.go b/sql/memory.go index d3ad793942..fd0eff631c 100644 --- a/sql/memory.go +++ b/sql/memory.go @@ -70,9 +70,9 @@ type Rows2Cache interface { // Add2 a new row to the cache. If there is no memory available, it will try to // free some memory. If after that there is still no memory available, it // will return an error and erase all the content of the cache. - Add2(ValueRow) error + AddValueRow(ValueRow) error // Get2 gets all rows. - Get2() []ValueRow + GetValueRow() []ValueRow } // ErrNoMemoryAvailable is returned when there is no more available memory. diff --git a/sql/plan/filter.go b/sql/plan/filter.go index b1252e8dc0..36d28e490e 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -154,13 +154,13 @@ func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { } // CanSupport implements the sql.ValueRowIter interface. -func (i *FilterIter) CanSupport(ctx *sql.Context) bool { +func (i *FilterIter) IsValueRowIter(ctx *sql.Context) bool { cond, ok := i.cond.(sql.ValueExpression) - if !ok || !cond.CanSupport() { + if !ok || !cond.IsValueExpression() { return false } childIter, ok := i.childIter.(sql.ValueRowIter) - if !ok || !childIter.CanSupport(ctx) { + if !ok || !childIter.IsValueRowIter(ctx) { return false } return true diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index 8afbdf7dd9..d728a7a1c4 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -313,7 +313,7 @@ func (i *IndexedTableAccess) getLookup2(ctx *sql.Context, row sql.ValueRow) (sql return i.lookup, nil } - key, err := i.lb.GetKey2(ctx, row) + key, err := i.lb.GetValueRowKey(ctx, row) if err != nil { return sql.IndexLookup{}, err } @@ -500,9 +500,9 @@ type lookupBuilderKey []interface{} // IndexedTableAccess nodes below an indexed join, for example. This struct is // also used to implement Expressioner on the IndexedTableAccess node. type LookupBuilder struct { - index sql.Index - keyExprs []sql.Expression - keyExprs2 []sql.ValueExpression + index sql.Index + keyExprs []sql.Expression + keyValExprs []sql.ValueExpression // When building the lookup, we will use an MySQLIndexBuilder. If the // extracted lookup value is NULL, but we have a non-NULL safe // comparison, then the lookup should return no values. But if the @@ -636,13 +636,13 @@ func (lb *LookupBuilder) GetKey(ctx *sql.Context, row sql.Row) (lookupBuilderKey return lb.key, nil } -func (lb *LookupBuilder) GetKey2(ctx *sql.Context, row sql.ValueRow) (lookupBuilderKey, error) { +func (lb *LookupBuilder) GetValueRowKey(ctx *sql.Context, row sql.ValueRow) (lookupBuilderKey, error) { if lb.key == nil { lb.key = make([]interface{}, len(lb.keyExprs)) } for i := range lb.keyExprs { var err error - lb.key[i], err = lb.keyExprs2[i].EvalValue(ctx, row) + lb.key[i], err = lb.keyValExprs[i].EvalValue(ctx, row) if err != nil { return nil, err } diff --git a/sql/plan/process.go b/sql/plan/process.go index 7d676c6433..210e32f540 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -331,9 +331,9 @@ func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { } // CanSupport implements the sql.ValueRowIter interface. -func (i *TrackedRowIter) CanSupport(ctx *sql.Context) bool { +func (i *TrackedRowIter) IsValueRowIter(ctx *sql.Context) bool { iter, ok := i.iter.(sql.ValueRowIter) - return ok && iter.CanSupport(ctx) + return ok && iter.IsValueRowIter(ctx) } func (i *TrackedRowIter) Close(ctx *sql.Context) error { diff --git a/sql/planbuilder/parse_old_test.go b/sql/planbuilder/parse_old_test.go index 01c79e9e78..66906fea55 100644 --- a/sql/planbuilder/parse_old_test.go +++ b/sql/planbuilder/parse_old_test.go @@ -1805,7 +1805,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -1844,7 +1844,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -1866,7 +1866,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -2634,13 +2634,13 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewLiteral(int8(2), types.Int8), - // Column2: expression.NewLiteral(int8(2), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(2), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, // { // Column: expression.NewLiteral(int8(1), types.Int8), - // Column2: expression.NewLiteral(int8(1), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(1), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3863,7 +3863,7 @@ func TestParse(t *testing.T) { // }, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3896,7 +3896,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3920,7 +3920,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3986,7 +3986,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("count", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4023,7 +4023,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("a"), - // Column2: expression.NewUnresolvedColumn("a"), + // ValueExprColumn: expression.NewUnresolvedColumn("a"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4299,7 +4299,7 @@ func TestParse(t *testing.T) { // "w1": sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4329,7 +4329,7 @@ func TestParse(t *testing.T) { // "w1": sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4838,7 +4838,7 @@ func TestParse(t *testing.T) { // ), true, nil, nil, []sql.SortField{ // { // Column: expression.NewLiteral(int8(2), types.Int8), - // Column2: expression.NewLiteral(int8(2), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(2), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index f0987a3ab7..034853fb33 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -105,9 +105,9 @@ func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow } // CanSupport implements the sql.ValueRowIter interface. -func (t *TransactionCommittingIter) CanSupport(ctx *sql.Context) bool { +func (t *TransactionCommittingIter) IsValueRowIter(ctx *sql.Context) bool { childIter, ok := t.childIter.(sql.ValueRowIter) - return ok && childIter.CanSupport(ctx) + return ok && childIter.IsValueRowIter(ctx) } func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { diff --git a/sql/rows.go b/sql/rows.go index 04fa261316..061727b7f3 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -83,15 +83,19 @@ func FormatRow(row Row) string { // TODO: most row iters need to be Disposable for CachedResult safety type RowIter interface { // Next retrieves the next row. It will return io.EOF if it's the last row. - // After retrieving the last row, Close will be automatically closed. + // After retrieving the last row, Close will be automatically called. Next(ctx *Context) (Row, error) Closer } +// ValueRowIter is an iterator that produces sql.ValueRows. type ValueRowIter interface { - RowIter + // NextValueRow retrieves the next ValueRow. It will return io.EOF if it's the last ValueRow. + // After retrieving the last ValueRow, Close will be automatically called. NextValueRow(ctx *Context) (ValueRow, error) - CanSupport(ctx *Context) bool + // IsValueRowIter checks whether this implementor and all its children support ValueRowIter. + IsValueRowIter(ctx *Context) bool + Closer } // RowIterToRows converts a row iterator to a slice of rows. diff --git a/sql/sort_field.go b/sql/sort_field.go index 950cf7074a..9cbcbbaef1 100644 --- a/sql/sort_field.go +++ b/sql/sort_field.go @@ -25,7 +25,7 @@ type SortField struct { // Column to order by. Column Expression // Column ValueExpression to order by. This is always the same value as Column, but avoids a type cast - Column2 ValueExpression + ValueExprColumn ValueExpression // Order type. Order SortOrder // NullOrdering defining how nulls will be ordered. @@ -50,12 +50,12 @@ func (sf SortFields) FromExpressions(exprs ...Expression) SortFields { } for i, expr := range exprs { - expr2, _ := expr.(ValueExpression) + valueExpr, _ := expr.(ValueExpression) fields[i] = SortField{ - Column: expr, - Column2: expr2, - NullOrdering: sf[i].NullOrdering, - Order: sf[i].Order, + Column: expr, + ValueExprColumn: valueExpr, + NullOrdering: sf[i].NullOrdering, + Order: sf[i].Order, } } return fields diff --git a/sql/table_iter.go b/sql/table_iter.go index 8beb1daf2c..7650d6f588 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -117,7 +117,7 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { } // CanSupport implements the sql.ValueRowIter interface. -func (i *TableRowIter) CanSupport(ctx *Context) bool { +func (i *TableRowIter) IsValueRowIter(ctx *Context) bool { if i.partition == nil { partition, err := i.partitions.Next(ctx) if err != nil { @@ -136,7 +136,7 @@ func (i *TableRowIter) CanSupport(ctx *Context) bool { } i.rows = valRowIter } - return i.rows.(ValueRowIter).CanSupport(ctx) + return i.rows.(ValueRowIter).IsValueRowIter(ctx) } func (i *TableRowIter) Close(ctx *Context) error { diff --git a/sql/type.go b/sql/type.go index 7f45664062..d15e79bfd1 100644 --- a/sql/type.go +++ b/sql/type.go @@ -105,6 +105,15 @@ type Type interface { fmt.Stringer } +// ValueType is an extension of the Type interface, that operates over sql.Values. +type ValueType interface { + Type + // SQLValue returns the sqltypes.Value for the given sql.Value. + // Implementations can optionally use |dest| to append + // serialized data, but should not mutate existing data. + SQLValue(*Context, Value, []byte) (sqltypes.Value, error) +} + // TrimStringToNumberPrefix will remove any white space for s and truncate any trailing non-numeric characters. func TrimStringToNumberPrefix(ctx *Context, s string, isInt bool) string { if isInt { @@ -292,11 +301,6 @@ func IsDecimalType(t Type) bool { return ok } -type ValueType interface { - Type - ToSQLValue(*Context, Value, []byte) (sqltypes.Value, error) -} - // SpatialColumnType is a node that contains a reference to all spatial types. type SpatialColumnType interface { // GetSpatialTypeSRID returns the SRID value for spatial types. diff --git a/sql/types/bit.go b/sql/types/bit.go index f4c4345c9f..a63cba0610 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -212,7 +212,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va } // ToSQLValue implements ValueType interface. -func (t BitType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t BitType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 54262009d3..ab346e84a8 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -475,7 +475,7 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(typ, valBytes), nil } -func (t datetimeType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 8164916a4d..681491c76d 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -330,7 +330,7 @@ func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } -func (t DecimalType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t DecimalType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/enum.go b/sql/types/enum.go index fdbe81d7d5..5de300526f 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -269,7 +269,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } -func (t EnumType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t EnumType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } @@ -285,7 +285,6 @@ func (t EnumType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltyp // TODO: write append style encoder res, ok := charset.Encoder().Encode(encodings.StringToBytes(value)) // TODO: use unsafe string to byte if !ok { - // return snippet of the converted value if len(value) > 50 { value = value[:50] } diff --git a/sql/types/number.go b/sql/types/number.go index 762712994f..f6935799a1 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -868,7 +868,7 @@ func (t NumberTypeImpl_) Zero2() sql.Value { } // ToSQLValue implements ValueType interface. -func (t NumberTypeImpl_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t NumberTypeImpl_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/set.go b/sql/types/set.go index 601e10b85a..e1f92d6faf 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -262,7 +262,7 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val return sqltypes.MakeTrusted(sqltypes.Set, val), nil } -func (t SetType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t SetType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/strings.go b/sql/types/strings.go index 067feb1522..a9f3b71ddb 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -791,7 +791,7 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. } // ToSQLValue implements ValueType interface. -func (t StringType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t StringType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/time.go b/sql/types/time.go index 696d278f08..a360408e78 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -268,7 +268,7 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes return sqltypes.MakeTrusted(sqltypes.Time, val), nil } -func (t TimespanType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t TimespanType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/types/year.go b/sql/types/year.go index 91ac78e280..c41137c8d6 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -172,7 +172,7 @@ func (t YearType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.V return sqltypes.MakeTrusted(sqltypes.Year, val), nil } -func (t YearType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { +func (t YearType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } diff --git a/sql/row_frame.go b/sql/value_row.go similarity index 92% rename from sql/row_frame.go rename to sql/value_row.go index a648ef7edb..9bf5296da6 100644 --- a/sql/row_frame.go +++ b/sql/value_row.go @@ -79,9 +79,9 @@ func (f *RowFrame) Recycle() { framePool.Put(f) } -// ValueRow returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so +// AsValueRow returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so // further modification to this frame may result in the returned value changing as well. -func (f *RowFrame) Row2() ValueRow { +func (f *RowFrame) AsValueRow() ValueRow { if f == nil { return nil } @@ -96,9 +96,9 @@ func (f *RowFrame) Row2() ValueRow { return rs } -// Row2Copy returns the row in this frame as a deep copy of the underlying byte arrays. Useful when reusing the +// ValueRowCopy returns the row in this frame as a deep copy of the underlying byte arrays. Useful when reusing the // RowFrame object via Clear() -func (f *RowFrame) Row2Copy() ValueRow { +func (f *RowFrame) ValueRowCopy() ValueRow { rs := make(ValueRow, len(f.Values)) // TODO: it would be faster here to just copy the entire value backing array in one pass for i := range f.Values { From c98a84b7702e8c2f2d5498655c1a2baed266f88f Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 26 Oct 2025 18:10:09 -0700 Subject: [PATCH 35/59] fix --- sql/plan/filter.go | 2 +- sql/table_iter.go | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 36d28e490e..4e78b50500 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -153,7 +153,7 @@ func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { } } -// CanSupport implements the sql.ValueRowIter interface. +// IsValueRowIter implements the sql.ValueRowIter interface. func (i *FilterIter) IsValueRowIter(ctx *sql.Context) bool { cond, ok := i.cond.(sql.ValueExpression) if !ok || !cond.IsValueExpression() { diff --git a/sql/table_iter.go b/sql/table_iter.go index 7650d6f588..8df5058274 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -24,6 +24,7 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter + valueRows ValueRowIter } var _ RowIter = (*TableRowIter)(nil) @@ -83,6 +84,7 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { return nil, ctx.Err() default: } + if i.partition == nil { partition, err := i.partitions.Next(ctx) if err != nil { @@ -96,27 +98,27 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { i.partition = partition } - if i.rows == nil { + if i.valueRows == nil { rows, err := i.table.PartitionRows(ctx, i.partition) if err != nil { return nil, err } - i.rows = rows + i.valueRows = rows.(ValueRowIter) } - row, err := i.rows.(ValueRowIter).NextValueRow(ctx) + row, err := i.valueRows.NextValueRow(ctx) if err != nil && err == io.EOF { if err = i.rows.Close(ctx); err != nil { return nil, err } i.partition = nil - i.rows = nil + i.valueRows = nil row, err = i.NextValueRow(ctx) } return row, err } -// CanSupport implements the sql.ValueRowIter interface. +// IsValueRowIter implements the sql.ValueRowIter interface. func (i *TableRowIter) IsValueRowIter(ctx *Context) bool { if i.partition == nil { partition, err := i.partitions.Next(ctx) @@ -125,7 +127,7 @@ func (i *TableRowIter) IsValueRowIter(ctx *Context) bool { } i.partition = partition } - if i.rows == nil { + if i.valueRows == nil { rows, err := i.table.PartitionRows(ctx, i.partition) if err != nil { return false @@ -134,9 +136,9 @@ func (i *TableRowIter) IsValueRowIter(ctx *Context) bool { if !ok { return false } - i.rows = valRowIter + i.valueRows = valRowIter } - return i.rows.(ValueRowIter).IsValueRowIter(ctx) + return i.valueRows.IsValueRowIter(ctx) } func (i *TableRowIter) Close(ctx *Context) error { @@ -146,5 +148,11 @@ func (i *TableRowIter) Close(ctx *Context) error { return err } } + if i.valueRows != nil { + if err := i.valueRows.Close(ctx); err != nil { + _ = i.partitions.Close(ctx) + return err + } + } return i.partitions.Close(ctx) } From d3154281dd234e7e79b3cd6677836a4d089e18f3 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 26 Oct 2025 18:11:37 -0700 Subject: [PATCH 36/59] comments --- sql/plan/process.go | 2 +- sql/rowexec/transaction_iters.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/plan/process.go b/sql/plan/process.go index 210e32f540..adcf3cdf68 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -330,7 +330,7 @@ func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { return row, nil } -// CanSupport implements the sql.ValueRowIter interface. +// IsValueRowIter implements the sql.ValueRowIter interface. func (i *TrackedRowIter) IsValueRowIter(ctx *sql.Context) bool { iter, ok := i.iter.(sql.ValueRowIter) return ok && iter.IsValueRowIter(ctx) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index 034853fb33..99b0041436 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -104,7 +104,7 @@ func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow return t.childIter.(sql.ValueRowIter).NextValueRow(ctx) } -// CanSupport implements the sql.ValueRowIter interface. +// IsValueRowIter implements the sql.ValueRowIter interface. func (t *TransactionCommittingIter) IsValueRowIter(ctx *sql.Context) bool { childIter, ok := t.childIter.(sql.ValueRowIter) return ok && childIter.IsValueRowIter(ctx) From b79f66093e3611dc04e4810a7f6c76ab7e0fc792 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 00:39:54 -0700 Subject: [PATCH 37/59] a ton of conversions --- sql/expression/comparison.go | 191 ++++++++++++++-------- sql/type.go | 3 + sql/types/bit.go | 27 +++- sql/types/conversion.go | 18 +++ sql/types/datetime.go | 151 +++++++++++++++-- sql/types/decimal.go | 101 ++++++++++-- sql/types/enum.go | 1 + sql/types/number.go | 305 ++++++++++++++++++----------------- sql/types/set.go | 1 + sql/types/strings.go | 107 +++++++++++- sql/types/time.go | 1 + sql/types/year.go | 94 +++++++++++ sql/value_row.go | 10 ++ sql/values/encoding.go | 72 ++++++--- 14 files changed, 819 insertions(+), 263 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 6703f70e95..7672aa7dee 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,7 +17,6 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -158,6 +157,59 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { return compareType.Compare(ctx, l, r) } +// CompareValue the two given values using the types of the expressions in the comparison. +func (c *comparison) CompareValue(ctx *sql.Context, row sql.ValueRow) (int, error) { + // TODO: avoid type assertions + lv, err := c.LeftChild.(sql.ValueExpression).EvalValue(ctx, row) + if err != nil { + return 0, err + } + rv, err := c.RightChild.(sql.ValueExpression).EvalValue(ctx, row) + if err != nil { + return 0, err + } + + if lv.IsNull() || rv.IsNull() { + return 0, nil + } + + lTyp, rTyp := c.LeftChild.Type().(sql.ValueType), c.RightChild.Type().(sql.ValueType) + if types.TypesEqual(lTyp, rTyp) { + return lTyp.(sql.ValueType).CompareValue(ctx, lv, rv) + } + + // TODO: enums + + // TODO: sets + + if types.IsNumber(lTyp) || types.IsNumber(rTyp) { + if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { + return types.Uint64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + if types.IsSigned(lTyp) && types.IsSigned(rTyp) { + return types.Int64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + if types.IsDecimal(lTyp) || types.IsDecimal(rTyp) { + return types.InternalDecimalType.(sql.ValueType).CompareValue(ctx, lv, rv) + } + return types.Float64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + return lTyp.CompareValue(ctx, lv, rv) +} + +// IsValueExpression returns whether every child supports sql.ValueExpression +func (c *comparison) IsValueExpression() bool { + l, ok := c.LeftChild.(sql.ValueExpression) + if !ok { + return false + } + r, ok := c.RightChild.(sql.ValueExpression) + if !ok { + return false + } + return l.IsValueExpression() && r.IsValueExpression() +} + func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { left, err := c.Left().Eval(ctx, row) if err != nil { @@ -523,71 +575,6 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -// EvalValue implements the sql.ValueExpression interface. -func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { - lv, err := gt.comparison.LeftChild.(sql.ValueExpression).EvalValue(ctx, row) - if err != nil { - return sql.Value{}, err - } - rv, err := gt.comparison.RightChild.(sql.ValueExpression).EvalValue(ctx, row) - if err != nil { - return sql.Value{}, err - } - - // TODO: move this logic into comparison - var cmp byte - if sqltypes.IsUnsigned(lv.Typ) && sqltypes.IsUnsigned(rv.Typ) { - l, cErr := types.ConvertValueToUint64(lv) - if cErr != nil { - return sql.Value{}, cErr - } - r, cErr := types.ConvertValueToUint64(rv) - if cErr != nil { - return sql.Value{}, cErr - } - if l > r { - cmp = 1 - } - } else { - l, cErr := types.ConvertValueToInt64(lv) - if cErr != nil { - return sql.Value{}, cErr - } - r, cErr := types.ConvertValueToInt64(rv) - if cErr != nil { - return sql.Value{}, cErr - } - if l > r { - cmp = 1 - } - } - - res := sql.Value{ - Val: []byte{cmp}, - Typ: sqltypes.Int8, - } - return res, nil -} - -// IsValueRowIter implements the ValueExpression interface. -func (gt *GreaterThan) IsValueExpression() bool { - l, ok := gt.comparison.LeftChild.(sql.ValueExpression) - if !ok { - return false - } - if !l.IsValueExpression() { - return false - } - r, ok := gt.comparison.RightChild.(sql.ValueExpression) - if !ok { - return false - } - if !r.IsValueExpression() { - return false - } - return true -} - // WithChildren implements the Expression interface. func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 2 { @@ -608,6 +595,23 @@ func (gt *GreaterThan) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := gt.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp != 1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (gt *GreaterThan) IsValueExpression() bool { + return gt.comparison.IsValueExpression() +} + // LessThan is a comparison that checks an expression is less than another. type LessThan struct { comparison @@ -633,10 +637,8 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if ErrNilOperand.Is(err) { return nil, nil } - return nil, err } - return result == -1, nil } @@ -660,6 +662,23 @@ func (lt *LessThan) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (lt *LessThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := lt.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp != -1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (lt *LessThan) IsValueExpression() bool { + return lt.comparison.IsValueExpression() +} + // GreaterThanOrEqual is a comparison that checks an expression is greater or equal to // another. type GreaterThanOrEqual struct { @@ -686,10 +705,8 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, if ErrNilOperand.Is(err) { return nil, nil } - return nil, err } - return result > -1, nil } @@ -713,6 +730,23 @@ func (gte *GreaterThanOrEqual) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (gte *GreaterThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := gte.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp == -1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (gte *GreaterThanOrEqual) IsValueExpression() bool { + return gte.comparison.IsValueExpression() +} + // LessThanOrEqual is a comparison that checks an expression is equal or lower than // another. type LessThanOrEqual struct { @@ -766,6 +800,23 @@ func (lte *LessThanOrEqual) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (lte *LessThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := lte.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp == 1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (lte *LessThanOrEqual) IsValueExpression() bool { + return lte.comparison.IsValueExpression() +} + var ( // ErrUnsupportedInOperand is returned when there is an invalid righthand // operand in an IN operator. diff --git a/sql/type.go b/sql/type.go index d15e79bfd1..3a36255839 100644 --- a/sql/type.go +++ b/sql/type.go @@ -108,6 +108,9 @@ type Type interface { // ValueType is an extension of the Type interface, that operates over sql.Values. type ValueType interface { Type + // CompareValue returns an integer comparing two sql.Values. + // The result will be 0 if a == b, -1 if a < b, and +1 if a > b. + CompareValue(*Context, Value, Value) (int, error) // SQLValue returns the sqltypes.Value for the given sql.Value. // Implementations can optionally use |dest| to append // serialized data, but should not mutate existing data. diff --git a/sql/types/bit.go b/sql/types/bit.go index a63cba0610..4e3719a582 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -103,6 +103,31 @@ func (t BitType_) Compare(ctx context.Context, a interface{}, b interface{}) (in return 0, nil } +// CompareValue implements the ValueType interface +func (t BitType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + + av, err := ConvertValueToUint64(ctx, a) + if err != nil { + return 0, err + } + bv, err := ConvertValueToUint64(ctx, b) + if err != nil { + return 0, err + } + + switch { + case av < bv: + return -1, nil + case av > bv: + return 1, nil + default: + return 0, nil + } +} + // Convert implements Type interface. func (t BitType_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -211,7 +236,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } -// ToSQLValue implements ValueType interface. +// SQLValue implements ValueType interface. func (t BitType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 2e5ca52748..0006f94cc1 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -472,6 +472,24 @@ func CompareNulls(a interface{}, b interface{}) (bool, int) { return false, 0 } +// CompareNullValues compares two sql.Values, and returns true if either is null. +// The returned integer represents the ordering, with a rule that states nulls +// as being ordered before non-nulls. +func CompareNullValues(a, b sql.Value) (bool, int) { + aIsNull := a.IsNull() + bIsNull := b.IsNull() + switch { + case aIsNull && bIsNull: + return true, 0 + case aIsNull && !bIsNull: + return false, 1 + case !aIsNull && bIsNull: + return false, -1 + default: + return false, 0 + } +} + // NumColumns returns the number of columns in a type. This is one for all // types, except tuples. func NumColumns(t sql.Type) int { diff --git a/sql/types/datetime.go b/sql/types/datetime.go index ab346e84a8..f6bbb18215 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -189,6 +189,29 @@ func (t datetimeType) Compare(ctx context.Context, a interface{}, b interface{}) return 0, nil } +// CompareValue implements the ValueType interface +func (t datetimeType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + at, err := ConvertValueToDatetime(ctx, a) + if err != nil { + return 0, err + } + bt, err := ConvertValueToDatetime(ctx, b) + if err != nil { + return 0, err + } + switch { + case at.Before(bt): + return -1, nil + case at.After(bt): + return 1, nil + default: + return 0, nil + } +} + // Convert implements Type interface. func (t datetimeType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -274,14 +297,14 @@ func (t datetimeType) ConvertWithoutRangeCheck(ctx context.Context, v interface{ } // TODO: consider not using time.Parse if we want to match MySQL exactly ('2010-06-03 11:22.:.:.:.:' is a valid timestamp) var parsed bool - res, parsed, err = t.parseDatetime(value) + res, parsed, err = parseDatetime(value) if !parsed { return zeroTime, ErrConvertingToTime.New(v) } case time.Time: res = value.UTC() - // For most integer values, we just return an error (but MySQL is more lenient for some of these). A special case - // is zero values, which are important when converting from postgres defaults. + // For most integer values, we just return an error (but MySQL is more lenient for some of these). A special case + // is zero values, which are important when converting from postgres defaults. case int: if value == 0 { return zeroTime, nil @@ -371,7 +394,7 @@ func (t datetimeType) ConvertWithoutRangeCheck(ctx context.Context, v interface{ return res, err } -func (t datetimeType) parseDatetime(value string) (time.Time, bool, error) { +func parseDatetime(value string) (time.Time, bool, error) { if t, err := time.Parse(TimezoneTimestampDatetimeLayout, value); err == nil { return t.UTC(), true, nil } @@ -475,6 +498,7 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(typ, valBytes), nil } +// SQLValue implements the ValueType interface. func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil @@ -482,11 +506,7 @@ func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt switch t.baseType { case sqltypes.Date: // TODO: move this to values package - x := values.ReadUint32(v.Val) - y := x >> 16 - m := (x & (255 << 8)) >> 8 - d := x & 255 - t := time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC) + t := values.ReadDatetime(v.Val) dest = t.AppendFormat(dest, sql.DateLayout) case sqltypes.Datetime, sqltypes.Timestamp: @@ -571,3 +591,116 @@ func ValidateTimestamp(t time.Time) interface{} { } return t } + +func ConvertValueToDatetime(ctx *sql.Context, v sql.Value) (time.Time, error) { + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Int24: + x := values.ReadInt24(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Uint24: + x := values.ReadUint24(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + if x.IsZero() { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Year: + return zeroTime, nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + return x, nil + case sqltypes.Time: + x := values.ReadInt64(v.Val) + if x == 0 { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(x) + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return x, nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return zeroTime, err + } + } + val := values.ReadString(v.Val) + res, parsed, err := parseDatetime(val) + if !parsed { + return zeroTime, ErrConvertingToTime.New(v) + } + return res, err + default: + return zeroTime, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + } +} diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 681491c76d..3e818b3c8d 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,14 +17,13 @@ package types import ( "context" "fmt" - "math/big" - "reflect" - "strings" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" + "math/big" + "reflect" + "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/values" @@ -139,6 +138,22 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) ( return af.Decimal.Cmp(bf.Decimal), nil } +// CompareValue implements the ValueType interface +func (t DecimalType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + aDec, err := ConvertValueToDecimal(ctx, a) + if err != nil { + return 0, err + } + bDec, err := ConvertValueToDecimal(ctx, b) + if err != nil { + return 0, err + } + return aDec.Cmp(bDec), nil +} + // Convert implements Type interface. func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { dec, err := t.ConvertToNullDecimal(v) @@ -334,14 +349,7 @@ func (t DecimalType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt if v.IsNull() { return sqltypes.NULL, nil } - // TODO: implement values.ReadDecimal - e := values.ReadInt32(v.Val[:values.Int32Size]) - s := values.ReadInt8(v.Val[values.Int32Size : values.Int32Size+values.Int8Size]) - b := big.NewInt(0).SetBytes(v.Val[values.Int32Size+values.Int8Size:]) - if s < 0 { - b = b.Neg(b) - } - d := decimal.NewFromBigInt(b, e) + d := values.ReadDecimal(v.Val) val := AppendAndSliceString(dest, t.DecimalValueStringFixed(d)) return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } @@ -402,3 +410,72 @@ func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string { return v.StringFixed(v.Exponent() * -1) } } + +// TODO: Should this take in precision and scale? +func ConvertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, error) { + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + return decimal.NewFromInt(x), nil + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + bi := new(big.Int).SetUint64(x) + return decimal.NewFromBigInt(bi, 0), nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + return decimal.NewFromFloat32(x), nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + return decimal.NewFromFloat(x), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + return x, nil + case sqltypes.Year: + x := values.ReadUint16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + s := x.UTC().Unix() + return decimal.NewFromInt(s), nil + case sqltypes.Time: + x := values.ReadInt64(v.Val) + return decimal.NewFromInt(x), nil + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return decimal.NewFromInt(x.UTC().Unix()), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return decimal.Decimal{}, err + } + } + x := values.ReadString(v.Val) + res, err := decimal.NewFromString(x) + if err != nil { + return decimal.Decimal{}, err + } + return res, nil + default: + return decimal.Decimal{}, ErrConvertingToDecimal.New(v) + } +} diff --git a/sql/types/enum.go b/sql/types/enum.go index 5de300526f..87388209cd 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -269,6 +269,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } +// SQLValue implements the ValueType interface. func (t EnumType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil diff --git a/sql/types/number.go b/sql/types/number.go index f6935799a1..6fec4c2ce4 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -210,6 +210,67 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } } +// CompareValue implements the ValueType interface +func (t NumberTypeImpl_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + + switch t.baseType { + case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: + ca, err := ConvertValueToUint64(ctx, a) + if err != nil { + return 0, err + } + cb, err := ConvertValueToUint64(ctx, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + case sqltypes.Float32, sqltypes.Float64: + ca, err := convertValueToFloat64(t, a) + if err != nil { + return 0, err + } + cb, err := convertValueToFloat64(t, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + default: + ca, err := ConvertValueToInt64(ctx, a) + if err != nil { + return 0, err + } + cb, err := ConvertValueToInt64(ctx, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + } +} + // Convert implements Type interface. func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { var err error @@ -728,146 +789,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { - switch t.baseType { - case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, err := ConvertValueToUint64(a) - if err != nil { - return 0, err - } - cb, err := ConvertValueToUint64(b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - case sqltypes.Float32, sqltypes.Float64: - ca, err := convertValueToFloat64(t, a) - if err != nil { - return 0, err - } - cb, err := convertValueToFloat64(t, b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - default: - ca, err := ConvertValueToInt64(a) - if err != nil { - return 0, err - } - cb, err := ConvertValueToInt64(b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - } -} - -func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { - panic("implement me") -} - -func (t NumberTypeImpl_) Zero2() sql.Value { - switch t.baseType { - case sqltypes.Int8: - x := values.WriteInt8(make([]byte, values.Int8Size), 0) - return sql.Value{ - Typ: query.Type_INT8, - Val: x, - } - case sqltypes.Int16: - x := values.WriteInt16(make([]byte, values.Int16Size), 0) - return sql.Value{ - Typ: query.Type_INT16, - Val: x, - } - case sqltypes.Int24: - x := values.WriteInt24(make([]byte, values.Int24Size), 0) - return sql.Value{ - Typ: query.Type_INT24, - Val: x, - } - case sqltypes.Int32: - x := values.WriteInt32(make([]byte, values.Int32Size), 0) - return sql.Value{ - Typ: query.Type_INT32, - Val: x, - } - case sqltypes.Int64: - x := values.WriteInt64(make([]byte, values.Int64Size), 0) - return sql.Value{ - Typ: query.Type_INT64, - Val: x, - } - case sqltypes.Uint8: - x := values.WriteUint8(make([]byte, values.Uint8Size), 0) - return sql.Value{ - Typ: query.Type_UINT8, - Val: x, - } - case sqltypes.Uint16: - x := values.WriteUint16(make([]byte, values.Uint16Size), 0) - return sql.Value{ - Typ: query.Type_UINT16, - Val: x, - } - case sqltypes.Uint24: - x := values.WriteUint24(make([]byte, values.Uint24Size), 0) - return sql.Value{ - Typ: query.Type_UINT24, - Val: x, - } - case sqltypes.Uint32: - x := values.WriteUint32(make([]byte, values.Uint32Size), 0) - return sql.Value{ - Typ: query.Type_UINT32, - Val: x, - } - case sqltypes.Uint64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } - case sqltypes.Float32: - x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sql.Value{ - Typ: query.Type_FLOAT32, - Val: x, - } - case sqltypes.Float64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } - default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) - } -} - -// ToSQLValue implements ValueType interface. +// SQLValue implements ValueType interface. func (t NumberTypeImpl_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil @@ -1151,9 +1073,9 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func ConvertValueToInt64(v sql.Value) (int64, error) { +func ConvertValueToInt64(ctx *sql.Context, v sql.Value) (int64, error) { switch v.Typ { - case query.Type_INT8: + case sqltypes.Int8: return int64(values.ReadInt8(v.Val)), nil case sqltypes.Int16: return int64(values.ReadInt16(v.Val)), nil @@ -1195,12 +1117,52 @@ func ConvertValueToInt64(v sql.Value) (int64, error) { return math.MinInt64, nil } return int64(math.Round(v)), nil + case sqltypes.Decimal: + v := values.ReadDecimal(v.Val) + if v.GreaterThan(dec_int64_max) { + return math.MaxInt64, nil + } + if v.LessThan(dec_int64_min) { + return math.MinInt64, nil + } + return v.Round(0).IntPart(), nil + case sqltypes.Year: + v := values.ReadUint16(v.Val) + return int64(v), nil + case sqltypes.Date: + v := values.ReadDate(v.Val) + return v.UTC().Unix(), nil + case sqltypes.Time: + v := values.ReadInt64(v.Val) + return v, nil + case sqltypes.Datetime, sqltypes.Timestamp: + v := values.ReadDatetime(v.Val) + return v.UTC().Unix(), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return 0, err + } + } + val := values.ReadString(v.Val) + truncStr, didTrunc := TruncateStringToInt(val) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(v.Typ, val) + } + i, pErr := strconv.ParseInt(truncStr, 10, 64) + if pErr != nil { + return 0, sql.ErrInvalidValue.New(v, v.Typ.String()) + } + return i, err + // TODO: enum, set, json? default: return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } } -func ConvertValueToUint64(v sql.Value) (uint64, error) { +func ConvertValueToUint64(ctx *sql.Context, v sql.Value) (uint64, error) { switch v.Typ { case sqltypes.Int8: return uint64(values.ReadInt8(v.Val)), nil @@ -1234,6 +1196,59 @@ func ConvertValueToUint64(v sql.Value) (uint64, error) { return math.MaxUint64, nil } return uint64(math.Round(v)), nil + case sqltypes.Decimal: + v := values.ReadDecimal(v.Val) + if v.GreaterThan(dec_uint64_max) { + return math.MaxUint64, nil + } + if v.LessThan(dec_zero) { + ret, _ := dec_uint64_max.Sub(v).Float64() + return uint64(math.Round(ret)), nil + } + return uint64(v.Round(0).IntPart()), nil + case sqltypes.Year: + v := values.ReadUint16(v.Val) + return uint64(v), nil + case sqltypes.Date: + v := values.ReadDate(v.Val) + return uint64(v.UTC().Unix()), nil + case sqltypes.Time: + v := values.ReadInt64(v.Val) + return uint64(v), nil + case sqltypes.Datetime, sqltypes.Timestamp: + v := values.ReadDatetime(v.Val) + return uint64(v.UTC().Unix()), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return 0, err + } + } + val := values.ReadString(v.Val) + + truncStr, didTrunc := TruncateStringToInt(val) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(v.Typ, val) + } + + var neg bool + if truncStr[0] == '+' { + truncStr = truncStr[1:] + } else if truncStr[0] == '-' { + truncStr = truncStr[1:] + neg = true + } + + i, pErr := strconv.ParseUint(truncStr, 10, 64) + if errors.Is(pErr, strconv.ErrRange) { + return math.MaxUint64, nil + } + if neg { + return math.MaxUint64 - i + 1, err + } + return i, err default: return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } @@ -1245,9 +1260,9 @@ func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.Convert return uint64(v.UTC().Unix()), sql.InRange, nil case int: if v < 0 { - return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil + return uint64(v), sql.OutOfRange, nil } - return uint64(v), sql.InRange, nil + return uint64(v), v > 0, nil case int8: if v < 0 { return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil diff --git a/sql/types/set.go b/sql/types/set.go index e1f92d6faf..2ca5a96988 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -262,6 +262,7 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val return sqltypes.MakeTrusted(sqltypes.Set, val), nil } +// SQLValue implements ValueType interface. func (t SetType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil diff --git a/sql/types/strings.go b/sql/types/strings.go index a9f3b71ddb..79ae2b25b8 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -17,6 +17,7 @@ package types import ( "context" "fmt" + "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" strings2 "strings" @@ -334,6 +335,110 @@ func (t StringType) Compare(ctx context.Context, a interface{}, b interface{}) ( } } +// CompareValue implements the sql.ValueType interface. +func (t StringType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + + // TODO: Possible to compare binary strings directly? + as, err := ConvertValueToString(ctx, a) + if err != nil { + return 0, err + } + bs, err := ConvertValueToString(ctx, b) + if err != nil { + return 0, err + } + + encoder := t.collation.CharacterSet().Encoder() + getRuneWeight := t.collation.Sorter() + for len(as) > 0 && len(bs) > 0 { + ar, aRead := encoder.NextRune(as) + br, bRead := encoder.NextRune(bs) + if aRead == 0 || bRead == 0 || aRead == utf8.RuneError || bRead == utf8.RuneError { + return 0, fmt.Errorf("malformed string encountered while comparing") + } + aWeight := getRuneWeight(ar) + bWeight := getRuneWeight(br) + if aWeight < bWeight { + return -1, nil + } + if aWeight > bWeight { + return 1, nil + } + as = as[aRead:] + bs = bs[bRead:] + } + + return 0, nil +} + +func ConvertValueToString(ctx *sql.Context, v sql.Value) (string, error) { + // TODO: fix allocation + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + return strconv.FormatInt(int64(x), 10), nil + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + return strconv.FormatInt(int64(x), 10), nil + case sqltypes.Int24: + x := values.ReadInt24(v.Val) + return strconv.FormatInt(int64(x), 10), nil + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + return strconv.FormatInt(int64(x), 10), nil + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + return strconv.FormatInt(x, 10), nil + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + return strconv.FormatUint(uint64(x), 10), nil + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + return strconv.FormatUint(uint64(x), 10), nil + case sqltypes.Uint24: + x := values.ReadUint24(v.Val) + return strconv.FormatUint(uint64(x), 10), nil + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + return strconv.FormatUint(uint64(x), 10), nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + return strconv.FormatUint(x, 10), nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + return strconv.FormatFloat(float64(x), 'f', -1, 32), nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + return strconv.FormatFloat(x, 'f', -1, 64), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + return x.String(), nil + case sqltypes.Year: + x := values.ReadInt16(v.Val) + return strconv.FormatInt(int64(x), 10), nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + return x.String(), nil + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return x.String(), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return "", err + } + } + return values.ReadString(v.Val), nil + default: + return "", fmt.Errorf("unsupported type %v", v.Typ) + } +} + // Convert implements Type interface. func (t StringType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -790,7 +895,7 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. return sqltypes.MakeTrusted(t.baseType, val), nil } -// ToSQLValue implements ValueType interface. +// SQLValue implements ValueType interface. func (t StringType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil diff --git a/sql/types/time.go b/sql/types/time.go index a360408e78..6c7dd3b135 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -268,6 +268,7 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes return sqltypes.MakeTrusted(sqltypes.Time, val), nil } +// SQLValue implements ValueType interface. func (t TimespanType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil diff --git a/sql/types/year.go b/sql/types/year.go index c41137c8d6..8dbc0ba311 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -65,6 +65,29 @@ func (t YearType_) Compare(ctx context.Context, a interface{}, b interface{}) (i return 1, nil } +// CompareValue implements the ValueType interface. +func (t YearType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + ay, err := ConvertValueToYear(ctx, a) + if err != nil { + return 0, err + } + by, err := ConvertValueToYear(ctx, b) + if err != nil { + return 0, err + } + switch { + case ay < by: + return -1, nil + case ay > by: + return 1, nil + default: + return 0, nil + } +} + // Convert implements Type interface. func (t YearType_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -172,6 +195,7 @@ func (t YearType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.V return sqltypes.MakeTrusted(sqltypes.Year, val), nil } +// SQLValue implements ValueType interface. func (t YearType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil @@ -205,3 +229,73 @@ func (t YearType_) Zero() interface{} { func (YearType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } + +func ConvertValueToYear(ctx *sql.Context, v sql.Value) (uint16, error) { + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + return uint16(x), nil + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + return uint16(x), nil + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + return uint16(x), nil + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + return uint16(x), nil + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + return uint16(x), nil + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + return x, nil + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + return uint16(x), nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + return uint16(x), nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + return uint16(x), nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + return uint16(x), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + return uint16(x.IntPart()), nil + case sqltypes.Year: + x := values.ReadUint16(v.Val) + return x, nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + return uint16(x.UTC().Unix()), nil + case sqltypes.Time: + x := values.ReadInt64(v.Val) + return uint16(x), nil + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return uint16(x.UTC().Unix()), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return 0, err + } + } + val := values.ReadString(v.Val) + truncStr, didTrunc := TruncateStringToInt(val) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(v.Typ, val) + } + i, pErr := strconv.ParseInt(truncStr, 10, 64) + if pErr != nil { + return 0, sql.ErrInvalidValue.New(v, v.Typ.String()) + } + return uint16(i), err + default: + return 0, ErrConvertingToYear.New(v) + } +} diff --git a/sql/value_row.go b/sql/value_row.go index 9bf5296da6..f9140c41c5 100644 --- a/sql/value_row.go +++ b/sql/value_row.go @@ -34,6 +34,16 @@ type Value struct { Typ query.Type } +var NullValue = Value{} +var FalseValue = Value{ + Val: []byte{0}, + Typ: query.Type_INT8, +} +var TrueValue = Value{ + Val: []byte{1}, + Typ: query.Type_INT8, +} + // ValueRow is a slice of values type ValueRow []Value diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 5db8eb3af4..1a526ace76 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -18,7 +18,11 @@ import ( "bytes" "encoding/binary" "fmt" + "github.com/shopspring/decimal" "math" + "math/big" + "time" + "unsafe" ) type Type struct { @@ -38,15 +42,18 @@ const ( Uint24Size ByteSize = 3 Int32Size ByteSize = 4 Uint32Size ByteSize = 4 - Int48Size ByteSize = 6 - Uint48Size ByteSize = 6 Int64Size ByteSize = 8 Uint64Size ByteSize = 8 Float32Size ByteSize = 4 Float64Size ByteSize = 8 + DecimalSize ByteSize = 5 + + DateSize ByteSize = 4 + TimeSize ByteSize = 8 + DatetimeSize ByteSize = 8 + TimestampSize ByteSize = 8 ) -const maxUint48 = uint64(1<<48 - 1) const maxUint24 = uint32(1<<24 - 1) type Collation uint16 @@ -109,6 +116,7 @@ func ReadBool(val []byte) bool { expectSize(val, Int8Size) return val[0] == 1 } + func ReadInt8(val []byte) int8 { expectSize(val, Int8Size) return int8(val[0]) @@ -163,20 +171,48 @@ func ReadUint64(val []byte) uint64 { func ReadFloat32(val []byte) float32 { expectSize(val, Float32Size) - return math.Float32frombits(ReadUint32(val)) + x := binary.LittleEndian.Uint32(val) + return math.Float32frombits(x) } func ReadFloat64(val []byte) float64 { expectSize(val, Float64Size) - return math.Float64frombits(ReadUint64(val)) + x := binary.LittleEndian.Uint64(val) + return math.Float64frombits(x) +} + +func ReadDecimal(val []byte) decimal.Decimal { + expectSize(val, DecimalSize) + e := ReadInt32(val[:Int32Size]) + s := ReadInt8(val[Int32Size : Int32Size+Int8Size]) + b := big.NewInt(0).SetBytes(val[Int32Size+Int8Size:]) + if s < 0 { + b = b.Neg(b) + } + return decimal.NewFromBigInt(b, e) +} + +func ReadDate(val []byte) time.Time { + expectSize(val, Uint32Size) + x := binary.LittleEndian.Uint32(val) + y := x >> 16 + m := (x & (255 << 8)) >> 8 + d := x & 255 + return time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC) } -func ReadString(val []byte, coll Collation) string { - // todo: fix allocation - return string(val) +func ReadDatetime(val []byte) time.Time { + expectSize(val, DatetimeSize) + ms := int64(binary.LittleEndian.Uint64(val)) + return time.UnixMicro(ms).UTC() } -func ReadBytes(val []byte, coll Collation) []byte { +func ReadString(val []byte) string { + // TODO: this is essentially encoding.BytesToString + return *(*string)(unsafe.Pointer(&val)) +} + +func ReadBytes(val []byte) []byte { // todo: fix collation return val } @@ -249,20 +285,6 @@ func WriteUint32(buf []byte, val uint32) []byte { return buf } -func WriteUint48(buf []byte, u uint64) []byte { - expectSize(buf, Uint48Size) - if u > maxUint48 { - panic("uint is greater than max uint48") - } - var tmp [8]byte - binary.LittleEndian.PutUint64(tmp[:], u) - // copy |tmp| to |buf| - buf[5], buf[4] = tmp[5], tmp[4] - buf[3], buf[2] = tmp[3], tmp[2] - buf[1], buf[0] = tmp[1], tmp[0] - return buf -} - func WriteInt64(buf []byte, val int64) []byte { expectSize(buf, Int64Size) binary.LittleEndian.PutUint64(buf, uint64(val)) @@ -349,9 +371,9 @@ func compare(typ Type, left, right []byte) int { case Float64Enc: return compareFloat64(ReadFloat64(left), ReadFloat64(right)) case StringEnc: - return compareString(ReadString(left, typ.Coll), ReadString(right, typ.Coll), typ.Coll) + return compareString(ReadString(left), ReadString(right), typ.Coll) case BytesEnc: - return compareBytes(ReadBytes(left, typ.Coll), ReadBytes(right, typ.Coll), typ.Coll) + return compareBytes(ReadBytes(left), ReadBytes(right), typ.Coll) default: panic("unknown encoding") } From a6e326a3fa7f751ec4424c503aefb3f8f23f3142 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 01:00:20 -0700 Subject: [PATCH 38/59] more stuff --- sql/expression/comparison.go | 6 ++++++ sql/expression/literal.go | 2 +- sql/types/datetime.go | 2 +- sql/types/decimal.go | 1 - 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 7672aa7dee..25b4c4e4d7 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -207,6 +207,12 @@ func (c *comparison) IsValueExpression() bool { if !ok { return false } + if _, ok := c.LeftChild.Type().(sql.ValueType); !ok { + return false + } + if _, ok := c.RightChild.Type().(sql.ValueType); !ok { + return false + } return l.IsValueExpression() && r.IsValueExpression() } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index f48027dae3..104c04fd97 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -141,7 +141,7 @@ func (lit *Literal) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, er return lit.val2, nil } -// IsValueRowIter implements the ValueExpression interface. +// IsValueExpression implements the ValueExpression interface. func (lit *Literal) IsValueExpression() bool { return types.IsInteger(lit.Typ) } diff --git a/sql/types/datetime.go b/sql/types/datetime.go index f6bbb18215..958b154eaa 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -701,6 +701,6 @@ func ConvertValueToDatetime(ctx *sql.Context, v sql.Value) (time.Time, error) { } return res, err default: - return zeroTime, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + return zeroTime, sql.ErrInvalidBaseType.New(v.Typ.String(), "datetime") } } diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 3e818b3c8d..14d150aeec 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -411,7 +411,6 @@ func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string { } } -// TODO: Should this take in precision and scale? func ConvertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, error) { switch v.Typ { case sqltypes.Int8: From f0cc6c626c31bc12c7888682a6e204458dad321c Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 29 Oct 2025 08:02:17 +0000 Subject: [PATCH 39/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 7 ++++--- sql/types/strings.go | 2 +- sql/values/encoding.go | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 14d150aeec..ebf14a8d3e 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,13 +17,14 @@ package types import ( "context" "fmt" + "math/big" + "reflect" + "strings" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" - "math/big" - "reflect" - "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/values" diff --git a/sql/types/strings.go b/sql/types/strings.go index 79ae2b25b8..e1ca2962e1 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -17,7 +17,6 @@ package types import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "reflect" "strconv" strings2 "strings" @@ -31,6 +30,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 1a526ace76..361aea2088 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -18,11 +18,12 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/shopspring/decimal" "math" "math/big" "time" "unsafe" + + "github.com/shopspring/decimal" ) type Type struct { From 5c592d828456d6b3a248a35228910026e2406b7c Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 01:04:39 -0700 Subject: [PATCH 40/59] remove todo --- sql/expression/comparison.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 25b4c4e4d7..6c01d0a0ad 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -178,10 +178,6 @@ func (c *comparison) CompareValue(ctx *sql.Context, row sql.ValueRow) (int, erro return lTyp.(sql.ValueType).CompareValue(ctx, lv, rv) } - // TODO: enums - - // TODO: sets - if types.IsNumber(lTyp) || types.IsNumber(rTyp) { if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { return types.Uint64.(sql.ValueType).CompareValue(ctx, lv, rv) From 12ca1fe566f1e0131577d121c18bb20cf008a6c1 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 11:48:05 -0700 Subject: [PATCH 41/59] so many types --- sql/expression/comparison_test.go | 5 + sql/types/number.go | 16 +- sql/types/number_test.go | 531 ++++++++++++++++++++++++++++++ sql/values/encoding.go | 1 - 4 files changed, 549 insertions(+), 4 deletions(-) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index a2ce6eeae1..4032e54ab5 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -224,3 +224,8 @@ func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { require.NoError(t, err) return v } + +func TestValueComparison(t *testing.T) { + require := require.New(t) + +} diff --git a/sql/types/number.go b/sql/types/number.go index 6fec4c2ce4..05aa5a165e 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1131,13 +1131,24 @@ func ConvertValueToInt64(ctx *sql.Context, v sql.Value) (int64, error) { return int64(v), nil case sqltypes.Date: v := values.ReadDate(v.Val) - return v.UTC().Unix(), nil + var res int64 + res += int64(v.Year() * 10000) + res += int64(v.Month() * 100) + res += int64(v.Day()) + return res, nil case sqltypes.Time: v := values.ReadInt64(v.Val) return v, nil case sqltypes.Datetime, sqltypes.Timestamp: v := values.ReadDatetime(v.Val) - return v.UTC().Unix(), nil + var res int64 + res += int64(v.Year() * 1_00_00_00_00_00) + res += int64(v.Month() * 1_00_00_00_00) + res += int64(v.Day() * 1_00_00_00) + res += int64(v.Hour() * 1_00_00) + res += int64(v.Minute() * 1_00) + res += int64(v.Second()) + return res, nil case sqltypes.Text, sqltypes.Blob: var err error if v.Val == nil { @@ -1156,7 +1167,6 @@ func ConvertValueToInt64(ctx *sql.Context, v sql.Value) (int64, error) { return 0, sql.ErrInvalidValue.New(v, v.Typ.String()) } return i, err - // TODO: enum, set, json? default: return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 695f61fcc0..33074e7994 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -15,7 +15,9 @@ package types import ( + "encoding/binary" "fmt" + "github.com/shopspring/decimal" "math" "reflect" "strconv" @@ -711,3 +713,532 @@ func TestTruncateStringToDouble(t *testing.T) { }) } } + +func serializeDecimal(dec decimal.Decimal) []byte { + var res []byte + coef := dec.Coefficient() + res = binary.LittleEndian.AppendUint32(res, uint32(dec.Exponent())) + res = append(res, byte(coef.Sign())) + res = append(res, coef.Bytes()...) + return res +} + +func serializeDate(date time.Time) []byte { + var res uint32 + res += uint32(date.Year() << 16) + res += uint32(date.Month() << 8) + res += uint32(date.Day()) + return binary.LittleEndian.AppendUint32(nil, res) +} + +func serializeDatetime(date time.Time) []byte { + return binary.LittleEndian.AppendUint64(nil, uint64(date.UnixMicro())) +} + +func TestConvertValueToInt64(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + zeroDate := serializeDate(time.Unix(0, 0).UTC()) + testDate := serializeDate(time.Date(2000, 01, 02, 0, 0, 0, 0, time.UTC)) + + zeroDatetime := serializeDatetime(time.Unix(0, 0).UTC()) + testDatetime := serializeDatetime(time.Date(2000, 01, 02, 12, 34, 56, 0, time.UTC)) + + tests := []struct { + val sql.Value + exp int64 + err bool + }{ + // Int8 -> Int64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Int8, + }, + exp: -128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: -1, + }, + + // Int16 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MinInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: -1, + }, + + // Int32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MinInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: -1, + }, + + // Int64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: -1, + }, + + // Uint8 -> Int64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + }, + + // Uint16 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + }, + + // Uint32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + }, + + // Uint64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + }, + + // Float32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxInt64, + }, + + // Float64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, + }, + + // Decimal -> Int64 + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MaxInt64, + }, + + // Year -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + }, + + // Date -> Int64 + { + val: sql.Value{ + Val: zeroDate, + Typ: sqltypes.Date, + }, + exp: 19700101, + }, + { + val: sql.Value{ + Val: testDate, + Typ: sqltypes.Date, + }, + exp: 20000102, + }, + + // Time -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Time, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Time, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Time, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Time, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Time, + }, + exp: -1, + }, + + // Datetime/Timestamp -> Int64 + { + val: sql.Value{ + Val: zeroDatetime, + Typ: sqltypes.Datetime, + }, + exp: 19700101000000, + }, + { + val: sql.Value{ + Val: testDatetime, + Typ: sqltypes.Datetime, + }, + exp: 20000102123456, + }, + + // Text -> Int64 + + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := ConvertValueToInt64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.exp, res) + }) + } +} + +func TestConvertValueToUint64(t *testing.T) { + // TODO +} + +func TestConvertValueToFloat64(t *testing.T) { + // TODO +} diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 361aea2088..20d904fe98 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -183,7 +183,6 @@ func ReadFloat64(val []byte) float64 { } func ReadDecimal(val []byte) decimal.Decimal { - expectSize(val, DecimalSize) e := ReadInt32(val[:Int32Size]) s := ReadInt8(val[Int32Size : Int32Size+Int8Size]) b := big.NewInt(0).SetBytes(val[Int32Size+Int8Size:]) From 0386edc5dcbfd900f8c31ccff8d6ef04a6db6129 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 12:30:51 -0700 Subject: [PATCH 42/59] limiting scope of changes --- sql/expression/comparison.go | 7 +- sql/types/datetime.go | 133 +--------- sql/types/enum.go | 5 + sql/types/number.go | 8 + sql/types/number_test.go | 480 +++++++++++++++++++++++++++++++---- sql/types/set.go | 5 + sql/types/strings.go | 102 +------- sql/types/time.go | 5 + 8 files changed, 462 insertions(+), 283 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 6c01d0a0ad..c9612656e7 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -190,6 +190,7 @@ func (c *comparison) CompareValue(ctx *sql.Context, row sql.ValueRow) (int, erro } return types.Float64.(sql.ValueType).CompareValue(ctx, lv, rv) } + return lTyp.CompareValue(ctx, lv, rv) } @@ -203,10 +204,8 @@ func (c *comparison) IsValueExpression() bool { if !ok { return false } - if _, ok := c.LeftChild.Type().(sql.ValueType); !ok { - return false - } - if _, ok := c.RightChild.Type().(sql.ValueType); !ok { + // TODO: only allow comparisons between Integers, Floats, Decimals, Bits and Year for now + if !types.IsNumber(c.LeftChild.Type()) || !types.IsNumber(c.RightChild.Type()) { return false } return l.IsValueExpression() && r.IsValueExpression() diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 958b154eaa..3dfdc6382c 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -191,25 +191,7 @@ func (t datetimeType) Compare(ctx context.Context, a interface{}, b interface{}) // CompareValue implements the ValueType interface func (t datetimeType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { - if hasNulls, res := CompareNullValues(a, b); hasNulls { - return res, nil - } - at, err := ConvertValueToDatetime(ctx, a) - if err != nil { - return 0, err - } - bt, err := ConvertValueToDatetime(ctx, b) - if err != nil { - return 0, err - } - switch { - case at.Before(bt): - return -1, nil - case at.After(bt): - return 1, nil - default: - return 0, nil - } + panic("TODO: implement CompareValue for DatetimeType") } // Convert implements Type interface. @@ -591,116 +573,3 @@ func ValidateTimestamp(t time.Time) interface{} { } return t } - -func ConvertValueToDatetime(ctx *sql.Context, v sql.Value) (time.Time, error) { - switch v.Typ { - case sqltypes.Int8: - x := values.ReadInt8(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Int16: - x := values.ReadInt16(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Int24: - x := values.ReadInt24(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Int32: - x := values.ReadInt32(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Int64: - x := values.ReadInt64(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Uint8: - x := values.ReadUint8(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Uint16: - x := values.ReadUint16(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Uint24: - x := values.ReadUint24(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Uint32: - x := values.ReadUint32(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Uint64: - x := values.ReadUint64(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Float32: - x := values.ReadFloat32(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Float64: - x := values.ReadFloat64(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Decimal: - x := values.ReadDecimal(v.Val) - if x.IsZero() { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Year: - return zeroTime, nil - case sqltypes.Date: - x := values.ReadDate(v.Val) - return x, nil - case sqltypes.Time: - x := values.ReadInt64(v.Val) - if x == 0 { - return zeroTime, nil - } - return zeroTime, ErrConvertingToTime.New(x) - case sqltypes.Datetime, sqltypes.Timestamp: - x := values.ReadDatetime(v.Val) - return x, nil - case sqltypes.Text, sqltypes.Blob: - var err error - if v.Val == nil { - v.Val, err = v.WrappedVal.Unwrap(ctx) - if err != nil { - return zeroTime, err - } - } - val := values.ReadString(v.Val) - res, parsed, err := parseDatetime(val) - if !parsed { - return zeroTime, ErrConvertingToTime.New(v) - } - return res, err - default: - return zeroTime, sql.ErrInvalidBaseType.New(v.Typ.String(), "datetime") - } -} diff --git a/sql/types/enum.go b/sql/types/enum.go index 87388209cd..680ebed0c4 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -158,6 +158,11 @@ func (t EnumType) Compare(ctx context.Context, a interface{}, b interface{}) (in return 0, nil } +// CompareValue implements the ValueType interface +func (t EnumType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for EnumType") +} + // Convert implements Type interface. func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { diff --git a/sql/types/number.go b/sql/types/number.go index 05aa5a165e..b4fc5b56cb 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1126,6 +1126,12 @@ func ConvertValueToInt64(ctx *sql.Context, v sql.Value) (int64, error) { return math.MinInt64, nil } return v.Round(0).IntPart(), nil + case sqltypes.Bit: + v := values.ReadUint64(v.Val) + if v > math.MaxInt64 { + return math.MaxInt64, nil + } + return int64(v), nil case sqltypes.Year: v := values.ReadUint16(v.Val) return int64(v), nil @@ -1216,6 +1222,8 @@ func ConvertValueToUint64(ctx *sql.Context, v sql.Value) (uint64, error) { return uint64(math.Round(ret)), nil } return uint64(v.Round(0).IntPart()), nil + case sqltypes.Bit: + return values.ReadUint64(v.Val), nil case sqltypes.Year: v := values.ReadUint16(v.Val) return uint64(v), nil diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 33074e7994..f9d29d2c7a 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -723,18 +723,6 @@ func serializeDecimal(dec decimal.Decimal) []byte { return res } -func serializeDate(date time.Time) []byte { - var res uint32 - res += uint32(date.Year() << 16) - res += uint32(date.Month() << 8) - res += uint32(date.Day()) - return binary.LittleEndian.AppendUint32(nil, res) -} - -func serializeDatetime(date time.Time) []byte { - return binary.LittleEndian.AppendUint64(nil, uint64(date.UnixMicro())) -} - func TestConvertValueToInt64(t *testing.T) { ctx := sql.NewEmptyContext() @@ -743,12 +731,6 @@ func TestConvertValueToInt64(t *testing.T) { minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) - zeroDate := serializeDate(time.Unix(0, 0).UTC()) - testDate := serializeDate(time.Date(2000, 01, 02, 0, 0, 0, 0, time.UTC)) - - zeroDatetime := serializeDatetime(time.Unix(0, 0).UTC()) - testDatetime := serializeDatetime(time.Date(2000, 01, 02, 12, 34, 56, 0, time.UTC)) - tests := []struct { val sql.Value exp int64 @@ -1119,6 +1101,36 @@ func TestConvertValueToInt64(t *testing.T) { exp: math.MaxInt64, }, + // Bit -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + // Year -> Int64 { val: sql.Value{ @@ -1148,78 +1160,458 @@ func TestConvertValueToInt64(t *testing.T) { }, exp: 2155, }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := ConvertValueToInt64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.exp, res) + }) + } +} + +func TestConvertValueToUint64(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp uint64 + err bool + }{ + // Int8 -> Uint64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Int8, + }, + exp: 128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: 255, + }, + + // Int16 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16 + 1, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: math.MaxUint16, + }, - // Date -> Int64 + // Int32 -> Uint64 { val: sql.Value{ - Val: zeroDate, - Typ: sqltypes.Date, + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, }, - exp: 19700101, + exp: 0, }, { val: sql.Value{ - Val: testDate, - Typ: sqltypes.Date, + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, }, - exp: 20000102, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32 + 1, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: math.MaxUint32, }, - // Time -> Int64 + // Int64 -> Uint64 { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), - Typ: sqltypes.Time, + Typ: sqltypes.Int64, }, exp: 0, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), - Typ: sqltypes.Time, + Typ: sqltypes.Int64, }, exp: 67, }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), - Typ: sqltypes.Time, + Typ: sqltypes.Int64, }, - exp: math.MinInt64, + exp: math.MaxInt64 + 1, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: math.MaxUint64, + }, + + // Uint8 -> Uint64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + }, + + // Uint16 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + }, + + // Uint32 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + }, + + // Uint64 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), - Typ: sqltypes.Time, + Typ: sqltypes.Uint64, }, exp: math.MaxInt64, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), - Typ: sqltypes.Time, + Typ: sqltypes.Uint64, }, - exp: -1, + exp: math.MaxInt64, + }, + + // Float32 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxInt64, + }, + + // Float64 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, }, - // Datetime/Timestamp -> Int64 + // Decimal -> Uint64 { val: sql.Value{ - Val: zeroDatetime, - Typ: sqltypes.Datetime, + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, }, - exp: 19700101000000, + exp: 0, }, { val: sql.Value{ - Val: testDatetime, - Typ: sqltypes.Datetime, + Val: maxInt64Dec, + Typ: sqltypes.Decimal, }, - exp: 20000102123456, + exp: math.MaxInt64, }, - // Text -> Int64 - + // Bit -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + + // Year -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + }, } for _, test := range tests { @@ -1235,10 +1627,6 @@ func TestConvertValueToInt64(t *testing.T) { } } -func TestConvertValueToUint64(t *testing.T) { - // TODO -} - func TestConvertValueToFloat64(t *testing.T) { // TODO } diff --git a/sql/types/set.go b/sql/types/set.go index 2ca5a96988..ccb642423b 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -156,6 +156,11 @@ func (t SetType) Compare(ctx context.Context, a interface{}, b interface{}) (int return 0, nil } +// CompareValue implements the ValueType interface +func (t SetType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for SetType") +} + // Convert implements Type interface. // Returns the string representing the given value if applicable. func (t SetType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { diff --git a/sql/types/strings.go b/sql/types/strings.go index e1ca2962e1..3dcdac8148 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -30,7 +30,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" - "github.com/dolthub/go-mysql-server/sql/values" ) const ( @@ -337,106 +336,7 @@ func (t StringType) Compare(ctx context.Context, a interface{}, b interface{}) ( // CompareValue implements the sql.ValueType interface. func (t StringType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { - if hasNulls, res := CompareNullValues(a, b); hasNulls { - return res, nil - } - - // TODO: Possible to compare binary strings directly? - as, err := ConvertValueToString(ctx, a) - if err != nil { - return 0, err - } - bs, err := ConvertValueToString(ctx, b) - if err != nil { - return 0, err - } - - encoder := t.collation.CharacterSet().Encoder() - getRuneWeight := t.collation.Sorter() - for len(as) > 0 && len(bs) > 0 { - ar, aRead := encoder.NextRune(as) - br, bRead := encoder.NextRune(bs) - if aRead == 0 || bRead == 0 || aRead == utf8.RuneError || bRead == utf8.RuneError { - return 0, fmt.Errorf("malformed string encountered while comparing") - } - aWeight := getRuneWeight(ar) - bWeight := getRuneWeight(br) - if aWeight < bWeight { - return -1, nil - } - if aWeight > bWeight { - return 1, nil - } - as = as[aRead:] - bs = bs[bRead:] - } - - return 0, nil -} - -func ConvertValueToString(ctx *sql.Context, v sql.Value) (string, error) { - // TODO: fix allocation - switch v.Typ { - case sqltypes.Int8: - x := values.ReadInt8(v.Val) - return strconv.FormatInt(int64(x), 10), nil - case sqltypes.Int16: - x := values.ReadInt16(v.Val) - return strconv.FormatInt(int64(x), 10), nil - case sqltypes.Int24: - x := values.ReadInt24(v.Val) - return strconv.FormatInt(int64(x), 10), nil - case sqltypes.Int32: - x := values.ReadInt32(v.Val) - return strconv.FormatInt(int64(x), 10), nil - case sqltypes.Int64: - x := values.ReadInt64(v.Val) - return strconv.FormatInt(x, 10), nil - case sqltypes.Uint8: - x := values.ReadUint8(v.Val) - return strconv.FormatUint(uint64(x), 10), nil - case sqltypes.Uint16: - x := values.ReadUint16(v.Val) - return strconv.FormatUint(uint64(x), 10), nil - case sqltypes.Uint24: - x := values.ReadUint24(v.Val) - return strconv.FormatUint(uint64(x), 10), nil - case sqltypes.Uint32: - x := values.ReadUint32(v.Val) - return strconv.FormatUint(uint64(x), 10), nil - case sqltypes.Uint64: - x := values.ReadUint64(v.Val) - return strconv.FormatUint(x, 10), nil - case sqltypes.Float32: - x := values.ReadFloat32(v.Val) - return strconv.FormatFloat(float64(x), 'f', -1, 32), nil - case sqltypes.Float64: - x := values.ReadFloat64(v.Val) - return strconv.FormatFloat(x, 'f', -1, 64), nil - case sqltypes.Decimal: - x := values.ReadDecimal(v.Val) - return x.String(), nil - case sqltypes.Year: - x := values.ReadInt16(v.Val) - return strconv.FormatInt(int64(x), 10), nil - case sqltypes.Date: - x := values.ReadDate(v.Val) - return x.String(), nil - case sqltypes.Datetime, sqltypes.Timestamp: - x := values.ReadDatetime(v.Val) - return x.String(), nil - case sqltypes.Text, sqltypes.Blob: - var err error - if v.Val == nil { - v.Val, err = v.WrappedVal.Unwrap(ctx) - if err != nil { - return "", err - } - } - return values.ReadString(v.Val), nil - default: - return "", fmt.Errorf("unsupported type %v", v.Typ) - } + panic("TODO: implement CompareValue for StringTypes") } // Convert implements Type interface. diff --git a/sql/types/time.go b/sql/types/time.go index 6c7dd3b135..fb4e153fc1 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -100,6 +100,11 @@ func (t TimespanType_) Compare(s context.Context, a interface{}, b interface{}) return as.Compare(bs), nil } +// CompareValue implements the ValueType interface +func (t TimespanType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for TimespanType") +} + func (t TimespanType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { return nil, sql.InRange, nil From aa3dfc1032a69d255c07fa058095adc3b9acbcb4 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 29 Oct 2025 19:32:57 +0000 Subject: [PATCH 43/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/number_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index f9d29d2c7a..0290dd3edd 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -17,7 +17,6 @@ package types import ( "encoding/binary" "fmt" - "github.com/shopspring/decimal" "math" "reflect" "strconv" @@ -26,6 +25,7 @@ import ( "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 706a096fe1f418b46a60571968ae2d18620d26bb Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 16:37:34 -0700 Subject: [PATCH 44/59] conversion tests --- sql/expression/comparison_test.go | 3 +- sql/types/bit.go | 4 +- sql/types/number.go | 372 ++++++++---------- sql/types/number_test.go | 609 +++++++++++++++++++++++++++++- 4 files changed, 746 insertions(+), 242 deletions(-) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index 4032e54ab5..ec904b7c62 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -226,6 +226,5 @@ func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { } func TestValueComparison(t *testing.T) { - require := require.New(t) - + // TODO } diff --git a/sql/types/bit.go b/sql/types/bit.go index 4e3719a582..ce24772911 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -109,11 +109,11 @@ func (t BitType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { return res, nil } - av, err := ConvertValueToUint64(ctx, a) + av, _, err := convertValueToUint64(ctx, a) if err != nil { return 0, err } - bv, err := ConvertValueToUint64(ctx, b) + bv, _, err := convertValueToUint64(ctx, b) if err != nil { return 0, err } diff --git a/sql/types/number.go b/sql/types/number.go index b4fc5b56cb..e10e668b6e 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -218,11 +218,11 @@ func (t NumberTypeImpl_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, er switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, err := ConvertValueToUint64(ctx, a) + ca, _, err := convertValueToUint64(ctx, a) if err != nil { return 0, err } - cb, err := ConvertValueToUint64(ctx, b) + cb, _, err := convertValueToUint64(ctx, b) if err != nil { return 0, err } @@ -235,11 +235,11 @@ func (t NumberTypeImpl_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, er } return +1, nil case sqltypes.Float32, sqltypes.Float64: - ca, err := convertValueToFloat64(t, a) + ca, err := convertValueToFloat64(ctx, a) if err != nil { return 0, err } - cb, err := convertValueToFloat64(t, b) + cb, err := convertValueToFloat64(ctx, b) if err != nil { return 0, err } @@ -252,11 +252,11 @@ func (t NumberTypeImpl_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, er } return +1, nil default: - ca, err := ConvertValueToInt64(ctx, a) + ca, _, err := convertValueToInt64(ctx, a) if err != nil { return 0, err } - cb, err := ConvertValueToInt64(ctx, b) + cb, _, err := convertValueToInt64(ctx, b) if err != nil { return 0, err } @@ -1073,205 +1073,6 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func ConvertValueToInt64(ctx *sql.Context, v sql.Value) (int64, error) { - switch v.Typ { - case sqltypes.Int8: - return int64(values.ReadInt8(v.Val)), nil - case sqltypes.Int16: - return int64(values.ReadInt16(v.Val)), nil - case sqltypes.Int24: - return int64(values.ReadInt24(v.Val)), nil - case sqltypes.Int32: - return int64(values.ReadInt32(v.Val)), nil - case sqltypes.Int64: - return values.ReadInt64(v.Val), nil - case sqltypes.Uint8: - return int64(values.ReadUint8(v.Val)), nil - case sqltypes.Uint16: - return int64(values.ReadUint16(v.Val)), nil - case sqltypes.Uint24: - return int64(values.ReadUint24(v.Val)), nil - case sqltypes.Uint32: - return int64(values.ReadUint32(v.Val)), nil - case sqltypes.Uint64: - v := values.ReadUint64(v.Val) - if v > math.MaxInt64 { - return math.MaxInt64, nil - } - return int64(v), nil - case sqltypes.Float32: - v := values.ReadFloat32(v.Val) - if v > float32(math.MaxInt64) { - return math.MaxInt64, nil - } - if v < float32(math.MinInt64) { - return math.MinInt64, nil - } - return int64(math.Round(float64(v))), nil - case sqltypes.Float64: - v := values.ReadFloat64(v.Val) - if v > float64(math.MaxInt64) { - return math.MaxInt64, nil - } - if v < float64(math.MinInt64) { - return math.MinInt64, nil - } - return int64(math.Round(v)), nil - case sqltypes.Decimal: - v := values.ReadDecimal(v.Val) - if v.GreaterThan(dec_int64_max) { - return math.MaxInt64, nil - } - if v.LessThan(dec_int64_min) { - return math.MinInt64, nil - } - return v.Round(0).IntPart(), nil - case sqltypes.Bit: - v := values.ReadUint64(v.Val) - if v > math.MaxInt64 { - return math.MaxInt64, nil - } - return int64(v), nil - case sqltypes.Year: - v := values.ReadUint16(v.Val) - return int64(v), nil - case sqltypes.Date: - v := values.ReadDate(v.Val) - var res int64 - res += int64(v.Year() * 10000) - res += int64(v.Month() * 100) - res += int64(v.Day()) - return res, nil - case sqltypes.Time: - v := values.ReadInt64(v.Val) - return v, nil - case sqltypes.Datetime, sqltypes.Timestamp: - v := values.ReadDatetime(v.Val) - var res int64 - res += int64(v.Year() * 1_00_00_00_00_00) - res += int64(v.Month() * 1_00_00_00_00) - res += int64(v.Day() * 1_00_00_00) - res += int64(v.Hour() * 1_00_00) - res += int64(v.Minute() * 1_00) - res += int64(v.Second()) - return res, nil - case sqltypes.Text, sqltypes.Blob: - var err error - if v.Val == nil { - v.Val, err = v.WrappedVal.Unwrap(ctx) - if err != nil { - return 0, err - } - } - val := values.ReadString(v.Val) - truncStr, didTrunc := TruncateStringToInt(val) - if didTrunc { - err = sql.ErrTruncatedIncorrect.New(v.Typ, val) - } - i, pErr := strconv.ParseInt(truncStr, 10, 64) - if pErr != nil { - return 0, sql.ErrInvalidValue.New(v, v.Typ.String()) - } - return i, err - default: - return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") - } -} - -func ConvertValueToUint64(ctx *sql.Context, v sql.Value) (uint64, error) { - switch v.Typ { - case sqltypes.Int8: - return uint64(values.ReadInt8(v.Val)), nil - case sqltypes.Int16: - return uint64(values.ReadInt16(v.Val)), nil - case sqltypes.Int24: - return uint64(values.ReadInt24(v.Val)), nil - case sqltypes.Int32: - return uint64(values.ReadInt32(v.Val)), nil - case sqltypes.Int64: - return uint64(values.ReadInt64(v.Val)), nil - case sqltypes.Uint8: - return uint64(values.ReadUint8(v.Val)), nil - case sqltypes.Uint16: - return uint64(values.ReadUint16(v.Val)), nil - case sqltypes.Uint24: - return uint64(values.ReadUint24(v.Val)), nil - case sqltypes.Uint32: - return uint64(values.ReadUint32(v.Val)), nil - case sqltypes.Uint64: - return values.ReadUint64(v.Val), nil - case sqltypes.Float32: - v := values.ReadFloat32(v.Val) - if v >= float32(math.MaxUint64) { - return math.MaxUint64, nil - } - return uint64(math.Round(float64(v))), nil - case sqltypes.Float64: - v := values.ReadFloat64(v.Val) - if v > float64(math.MaxUint64) { - return math.MaxUint64, nil - } - return uint64(math.Round(v)), nil - case sqltypes.Decimal: - v := values.ReadDecimal(v.Val) - if v.GreaterThan(dec_uint64_max) { - return math.MaxUint64, nil - } - if v.LessThan(dec_zero) { - ret, _ := dec_uint64_max.Sub(v).Float64() - return uint64(math.Round(ret)), nil - } - return uint64(v.Round(0).IntPart()), nil - case sqltypes.Bit: - return values.ReadUint64(v.Val), nil - case sqltypes.Year: - v := values.ReadUint16(v.Val) - return uint64(v), nil - case sqltypes.Date: - v := values.ReadDate(v.Val) - return uint64(v.UTC().Unix()), nil - case sqltypes.Time: - v := values.ReadInt64(v.Val) - return uint64(v), nil - case sqltypes.Datetime, sqltypes.Timestamp: - v := values.ReadDatetime(v.Val) - return uint64(v.UTC().Unix()), nil - case sqltypes.Text, sqltypes.Blob: - var err error - if v.Val == nil { - v.Val, err = v.WrappedVal.Unwrap(ctx) - if err != nil { - return 0, err - } - } - val := values.ReadString(v.Val) - - truncStr, didTrunc := TruncateStringToInt(val) - if didTrunc { - err = sql.ErrTruncatedIncorrect.New(v.Typ, val) - } - - var neg bool - if truncStr[0] == '+' { - truncStr = truncStr[1:] - } else if truncStr[0] == '-' { - truncStr = truncStr[1:] - neg = true - } - - i, pErr := strconv.ParseUint(truncStr, 10, 64) - if errors.Is(pErr, strconv.ErrRange) { - return math.MaxUint64, nil - } - if neg { - return math.MaxUint64 - i + 1, err - } - return i, err - default: - return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") - } -} - func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: @@ -1460,34 +1261,167 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { +func convertValueToInt64(ctx *sql.Context, v sql.Value) (int64, sql.ConvertInRange, error) { switch v.Typ { - case query.Type_INT8: + case sqltypes.Int8: + return int64(values.ReadInt8(v.Val)), sql.InRange, nil + case sqltypes.Int16: + return int64(values.ReadInt16(v.Val)), sql.InRange, nil + case sqltypes.Int24: + return int64(values.ReadInt24(v.Val)), sql.InRange, nil + case sqltypes.Int32: + return int64(values.ReadInt32(v.Val)), sql.InRange, nil + case sqltypes.Int64: + return values.ReadInt64(v.Val), sql.InRange, nil + case sqltypes.Uint8: + return int64(values.ReadUint8(v.Val)), sql.InRange, nil + case sqltypes.Uint16: + return int64(values.ReadUint16(v.Val)), sql.InRange, nil + case sqltypes.Uint24: + return int64(values.ReadUint24(v.Val)), sql.InRange, nil + case sqltypes.Uint32: + return int64(values.ReadUint32(v.Val)), sql.InRange, nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + if x > math.MaxInt64 { + return math.MaxInt64, sql.OutOfRange, nil + } + return int64(x), sql.InRange, nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + if x > float32(math.MaxInt64) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x < float32(math.MinInt64) { + return math.MinInt64, sql.OutOfRange, nil + } + return int64(math.Round(float64(x))), sql.InRange, nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + if x > float64(math.MaxInt64) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x < float64(math.MinInt64) { + return math.MinInt64, sql.OutOfRange, nil + } + return int64(math.Round(x)), sql.InRange, nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + if x.GreaterThan(dec_int64_max) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x.LessThan(dec_int64_min) { + return math.MinInt64, sql.OutOfRange, nil + } + return x.Round(0).IntPart(), sql.InRange, nil + case sqltypes.Bit: + x := values.ReadUint64(v.Val) + if x > math.MaxInt64 { + return math.MaxInt64, sql.OutOfRange, nil + } + return int64(x), sql.InRange, nil + case sqltypes.Year: + return int64(values.ReadUint16(v.Val)), sql.InRange, nil + default: + return 0, sql.InRange, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + } +} + +func convertValueToUint64(ctx *sql.Context, v sql.Value) (uint64, sql.ConvertInRange, error) { + switch v.Typ { + case sqltypes.Int8: + return uint64(values.ReadInt8(v.Val)), sql.InRange, nil + case sqltypes.Int16: + return uint64(values.ReadInt16(v.Val)), sql.InRange, nil + case sqltypes.Int24: + return uint64(values.ReadInt24(v.Val)), sql.InRange, nil + case sqltypes.Int32: + return uint64(values.ReadInt32(v.Val)), sql.InRange, nil + case sqltypes.Int64: + return uint64(values.ReadInt64(v.Val)), sql.InRange, nil + case sqltypes.Uint8: + return uint64(values.ReadUint8(v.Val)), sql.InRange, nil + case sqltypes.Uint16: + return uint64(values.ReadUint16(v.Val)), sql.InRange, nil + case sqltypes.Uint24: + return uint64(values.ReadUint24(v.Val)), sql.InRange, nil + case sqltypes.Uint32: + return uint64(values.ReadUint32(v.Val)), sql.InRange, nil + case sqltypes.Uint64: + return values.ReadUint64(v.Val), sql.InRange, nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + if x > float32(math.MaxUint64) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x < 0 { + return uint64(x), sql.OutOfRange, nil + } + return uint64(math.Round(float64(x))), sql.InRange, nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + if x > float64(math.MaxUint64) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x < 0 { + return uint64(x), sql.OutOfRange, nil + } + return uint64(math.Round(x)), sql.InRange, nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + if x.GreaterThan(dec_uint64_max) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x.LessThan(dec_zero) { + ret, _ := dec_uint64_max.Sub(x).Float64() + return uint64(math.Round(ret)), sql.OutOfRange, nil + } + return uint64(x.Round(0).IntPart()), sql.InRange, nil + case sqltypes.Bit: + return values.ReadUint64(v.Val), sql.InRange, nil + case sqltypes.Year: + return uint64(values.ReadUint16(v.Val)), sql.InRange, nil + default: + return 0, sql.InRange, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + } +} + +func convertValueToFloat64(ctx *sql.Context, v sql.Value) (float64, error) { + switch v.Typ { + case sqltypes.Int8: return float64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: + case sqltypes.Int16: return float64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: + case sqltypes.Int24: return float64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: + case sqltypes.Int32: return float64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: + case sqltypes.Int64: return float64(values.ReadInt64(v.Val)), nil - case query.Type_UINT8: + case sqltypes.Uint8: return float64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: + case sqltypes.Uint16: return float64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: + case sqltypes.Uint24: return float64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: + case sqltypes.Uint32: return float64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: + case sqltypes.Uint64: return float64(values.ReadUint64(v.Val)), nil - case query.Type_FLOAT32: + case sqltypes.Float32: return float64(values.ReadFloat32(v.Val)), nil - case query.Type_FLOAT64: + case sqltypes.Float64: return values.ReadFloat64(v.Val), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + f, _ := x.Float64() + return f, nil + case sqltypes.Bit: + return float64(values.ReadUint64(v.Val)), nil + case sqltypes.Year: + return float64(values.ReadUint16(v.Val)), nil default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) + return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } } diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 0290dd3edd..766c091090 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -734,6 +734,7 @@ func TestConvertValueToInt64(t *testing.T) { tests := []struct { val sql.Value exp int64 + rng sql.ConvertInRange err bool }{ // Int8 -> Int64 @@ -743,6 +744,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int8, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -750,6 +752,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int8, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -757,6 +760,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int8, }, exp: -128, + rng: sql.InRange, }, { val: sql.Value{ @@ -764,6 +768,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int8, }, exp: -1, + rng: sql.InRange, }, // Int16 -> Int64 @@ -773,6 +778,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int16, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -780,6 +786,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int16, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -787,6 +794,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int16, }, exp: math.MinInt16, + rng: sql.InRange, }, { val: sql.Value{ @@ -794,6 +802,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int16, }, exp: math.MaxInt16, + rng: sql.InRange, }, { val: sql.Value{ @@ -801,6 +810,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int16, }, exp: -1, + rng: sql.InRange, }, // Int32 -> Int64 @@ -810,6 +820,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -817,6 +828,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int32, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -824,6 +836,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int32, }, exp: math.MinInt32, + rng: sql.InRange, }, { val: sql.Value{ @@ -831,6 +844,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int32, }, exp: math.MaxInt32, + rng: sql.InRange, }, { val: sql.Value{ @@ -838,6 +852,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int32, }, exp: -1, + rng: sql.InRange, }, // Int64 -> Int64 @@ -847,6 +862,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -854,6 +870,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int64, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -861,6 +878,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int64, }, exp: math.MinInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -868,6 +886,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int64, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -875,6 +894,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Int64, }, exp: -1, + rng: sql.InRange, }, // Uint8 -> Int64 @@ -884,6 +904,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -891,6 +912,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -898,6 +920,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 128, + rng: sql.InRange, }, { val: sql.Value{ @@ -905,6 +928,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 255, + rng: sql.InRange, }, // Uint16 -> Int64 @@ -914,6 +938,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -921,6 +946,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -928,6 +954,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: math.MaxInt16, + rng: sql.InRange, }, { val: sql.Value{ @@ -935,6 +962,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: math.MaxUint16, + rng: sql.InRange, }, // Uint32 -> Int64 @@ -944,6 +972,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -951,6 +980,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -958,6 +988,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: math.MaxInt32, + rng: sql.InRange, }, { val: sql.Value{ @@ -965,6 +996,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: math.MaxUint32, + rng: sql.InRange, }, // Uint64 -> Int64 @@ -974,6 +1006,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -981,6 +1014,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -988,6 +1022,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -995,6 +1030,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: math.MaxInt64, + rng: sql.OutOfRange, }, // Float32 -> Int64 @@ -1004,6 +1040,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1011,6 +1048,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float32, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ @@ -1018,6 +1056,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float32, }, exp: math.MinInt64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1025,6 +1064,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float32, }, exp: math.MaxInt64, + rng: sql.OutOfRange, }, // Float64 -> Int64 @@ -1034,6 +1074,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1041,6 +1082,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ @@ -1048,6 +1090,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: math.MinInt64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1055,6 +1098,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: math.MaxInt64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1062,6 +1106,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: math.MinInt64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1069,6 +1114,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Float64, }, exp: math.MaxInt64, + rng: sql.OutOfRange, }, // Decimal -> Int64 @@ -1078,6 +1124,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1085,6 +1132,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ @@ -1092,6 +1140,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: math.MinInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -1099,6 +1148,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: math.MaxInt64, + rng: sql.InRange, }, // Bit -> Int64 @@ -1108,6 +1158,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Bit, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1115,6 +1166,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Bit, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1122,6 +1174,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Bit, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -1129,6 +1182,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Bit, }, exp: math.MaxInt64, + rng: sql.OutOfRange, }, // Year -> Int64 @@ -1138,6 +1192,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Year, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1145,6 +1200,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Year, }, exp: 1967, + rng: sql.InRange, }, { val: sql.Value{ @@ -1152,6 +1208,7 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Year, }, exp: 1901, + rng: sql.InRange, }, { val: sql.Value{ @@ -1159,18 +1216,20 @@ func TestConvertValueToInt64(t *testing.T) { Typ: sqltypes.Year, }, exp: 2155, + rng: sql.InRange, }, } for _, test := range tests { t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { - res, err := ConvertValueToInt64(ctx, test.val) + res, rng, err := convertValueToInt64(ctx, test.val) if test.err { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, test.exp, res) + require.Equal(t, test.rng, rng) }) } } @@ -1186,6 +1245,7 @@ func TestConvertValueToUint64(t *testing.T) { tests := []struct { val sql.Value exp uint64 + rng sql.ConvertInRange err bool }{ // Int8 -> Uint64 @@ -1195,6 +1255,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int8, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1202,20 +1263,23 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int8, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ - Val: []byte{128}, + Val: []byte{127}, Typ: sqltypes.Int8, }, - exp: 128, + exp: 127, + rng: sql.InRange, }, { val: sql.Value{ Val: []byte{255}, Typ: sqltypes.Int8, }, - exp: 255, + exp: math.MaxUint64, + rng: sql.InRange, }, // Int16 -> Uint64 @@ -1225,6 +1289,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int16, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1232,6 +1297,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int16, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1239,20 +1305,23 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int16, }, exp: math.MaxInt16, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), Typ: sqltypes.Int16, }, - exp: math.MaxInt16 + 1, + exp: math.MaxUint64 - math.MaxInt16, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), Typ: sqltypes.Int16, }, - exp: math.MaxUint16, + exp: math.MaxUint64, + rng: sql.InRange, }, // Int32 -> Uint64 @@ -1262,6 +1331,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1269,6 +1339,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int32, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1276,20 +1347,23 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int32, }, exp: math.MaxInt32, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), Typ: sqltypes.Int32, }, - exp: math.MaxInt32 + 1, + exp: math.MaxUint64 - math.MaxInt32, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), Typ: sqltypes.Int32, }, - exp: math.MaxUint32, + exp: math.MaxUint64, + rng: sql.InRange, }, // Int64 -> Uint64 @@ -1299,6 +1373,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1306,6 +1381,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int64, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1313,6 +1389,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int64, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ @@ -1320,6 +1397,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int64, }, exp: math.MaxInt64 + 1, + rng: sql.InRange, }, { val: sql.Value{ @@ -1327,6 +1405,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Int64, }, exp: math.MaxUint64, + rng: sql.InRange, }, // Uint8 -> Uint64 @@ -1336,6 +1415,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1343,6 +1423,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1350,6 +1431,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 128, + rng: sql.InRange, }, { val: sql.Value{ @@ -1357,6 +1439,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint8, }, exp: 255, + rng: sql.InRange, }, // Uint16 -> Uint64 @@ -1366,6 +1449,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1373,6 +1457,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1380,6 +1465,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: math.MaxInt16, + rng: sql.InRange, }, { val: sql.Value{ @@ -1387,6 +1473,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint16, }, exp: math.MaxUint16, + rng: sql.InRange, }, // Uint32 -> Uint64 @@ -1396,6 +1483,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1403,6 +1491,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1410,6 +1499,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: math.MaxInt32, + rng: sql.InRange, }, { val: sql.Value{ @@ -1417,6 +1507,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint32, }, exp: math.MaxUint32, + rng: sql.InRange, }, // Uint64 -> Uint64 @@ -1426,6 +1517,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1433,6 +1525,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1440,13 +1533,15 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Uint64, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), Typ: sqltypes.Uint64, }, - exp: math.MaxInt64, + exp: math.MaxUint64, + rng: sql.InRange, }, // Float32 -> Uint64 @@ -1456,6 +1551,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float32, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1463,6 +1559,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float32, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ @@ -1470,13 +1567,15 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float32, }, exp: 0, + rng: sql.OutOfRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), Typ: sqltypes.Float32, }, - exp: math.MaxInt64, + exp: math.MaxUint64, + rng: sql.OutOfRange, }, // Float64 -> Uint64 @@ -1486,6 +1585,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1493,6 +1593,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ @@ -1500,13 +1601,15 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 0, + rng: sql.OutOfRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), Typ: sqltypes.Float64, }, - exp: math.MaxInt64, + exp: math.MaxUint64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1514,13 +1617,15 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Float64, }, exp: 0, + rng: sql.OutOfRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), Typ: sqltypes.Float64, }, - exp: math.MaxInt64, + exp: math.MaxUint64, + rng: sql.OutOfRange, }, // Decimal -> Uint64 @@ -1530,6 +1635,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1537,13 +1643,15 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: 123, + rng: sql.InRange, }, { val: sql.Value{ Val: minInt64Dec, Typ: sqltypes.Decimal, }, - exp: 0, + exp: math.MaxUint64, + rng: sql.OutOfRange, }, { val: sql.Value{ @@ -1551,15 +1659,17 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Decimal, }, exp: math.MaxInt64, + rng: sql.InRange, }, - // Bit -> Int64 + // Bit -> Uint64 { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), Typ: sqltypes.Bit, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1567,6 +1677,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Bit, }, exp: 67, + rng: sql.InRange, }, { val: sql.Value{ @@ -1574,22 +1685,25 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Bit, }, exp: math.MaxInt64, + rng: sql.InRange, }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), Typ: sqltypes.Bit, }, - exp: math.MaxInt64, + exp: math.MaxUint64, + rng: sql.InRange, }, - // Year -> Int64 + // Year -> Uint64 { val: sql.Value{ Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), Typ: sqltypes.Year, }, exp: 0, + rng: sql.InRange, }, { val: sql.Value{ @@ -1597,6 +1711,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Year, }, exp: 1967, + rng: sql.InRange, }, { val: sql.Value{ @@ -1604,6 +1719,7 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Year, }, exp: 1901, + rng: sql.InRange, }, { val: sql.Value{ @@ -1611,22 +1727,477 @@ func TestConvertValueToUint64(t *testing.T) { Typ: sqltypes.Year, }, exp: 2155, + rng: sql.InRange, }, } for _, test := range tests { - t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { - res, err := ConvertValueToInt64(ctx, test.val) + t.Run(fmt.Sprintf("Val: %v Typ: %v to UINT64", test.val.Val, test.val.Typ), func(t *testing.T) { + res, rng, err := convertValueToUint64(ctx, test.val) if test.err { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, test.exp, res) + require.Equal(t, test.rng, rng) }) } } func TestConvertValueToFloat64(t *testing.T) { - // TODO + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp float64 + err bool + }{ + // Int8 -> Float64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{127}, + Typ: sqltypes.Int8, + }, + exp: 127, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: -1, + }, + + // Int16 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MinInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: -1, + }, + + // Int32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MinInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: -1, + }, + + // Int64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: -1, + }, + + // Uint8 -> Float64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + }, + + // Uint16 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + }, + + // Uint32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + }, + + // Uint64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxUint64, + }, + + // Float32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123.456, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: -math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxFloat32, + }, + + // Float64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: -math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: -math.MaxFloat64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxFloat64, + }, + + // Decimal -> Float64 + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123.456, + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MaxInt64, + }, + + // Bit -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxUint64, + }, + + // Year -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + }, + } + + epsilon := 0.01 + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := convertValueToFloat64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + if test.exp == 0 { + require.Zero(t, res) + return + } + require.InEpsilonf(t, test.exp, res, epsilon, fmt.Sprintf("Actual is: %v", res)) + }) + } } From 4caa38dedecd14733ea96ebe7ff1253b818ae2d5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 17:02:02 -0700 Subject: [PATCH 45/59] remove bad tests --- sql/types/number_test.go | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 766c091090..7bcacaffc5 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -1563,7 +1563,7 @@ func TestConvertValueToUint64(t *testing.T) { }, { val: sql.Value{ - Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), Typ: sqltypes.Float32, }, exp: 0, @@ -1595,14 +1595,6 @@ func TestConvertValueToUint64(t *testing.T) { exp: 123, rng: sql.InRange, }, - { - val: sql.Value{ - Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), - Typ: sqltypes.Float64, - }, - exp: 0, - rng: sql.OutOfRange, - }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), @@ -1611,14 +1603,6 @@ func TestConvertValueToUint64(t *testing.T) { exp: math.MaxUint64, rng: sql.OutOfRange, }, - { - val: sql.Value{ - Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), - Typ: sqltypes.Float64, - }, - exp: 0, - rng: sql.OutOfRange, - }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), From 87ac0f521663e5cf9f1cc1f8ae600c023c7eabd8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 17:24:28 -0700 Subject: [PATCH 46/59] fix --- sql/types/number_test.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 7bcacaffc5..f9ea137dc0 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -1561,14 +1561,6 @@ func TestConvertValueToUint64(t *testing.T) { exp: 123, rng: sql.InRange, }, - { - val: sql.Value{ - Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), - Typ: sqltypes.Float32, - }, - exp: 0, - rng: sql.OutOfRange, - }, { val: sql.Value{ Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), From 74d03cf3cf01e7565d8f6e2c83b194dbd3521300 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Oct 2025 17:30:58 -0700 Subject: [PATCH 47/59] asdf --- sql/types/number_test.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index f9ea137dc0..9980284e22 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -1239,7 +1239,6 @@ func TestConvertValueToUint64(t *testing.T) { zeroDec := serializeDecimal(decimal.Zero) testDec := serializeDecimal(decimal.NewFromFloat(123.456)) - minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) tests := []struct { @@ -1621,14 +1620,6 @@ func TestConvertValueToUint64(t *testing.T) { exp: 123, rng: sql.InRange, }, - { - val: sql.Value{ - Val: minInt64Dec, - Typ: sqltypes.Decimal, - }, - exp: math.MaxUint64, - rng: sql.OutOfRange, - }, { val: sql.Value{ Val: maxInt64Dec, From 9b5519b7e0b8e7e7c04672f61afd973a5cb065e7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 11:03:20 -0700 Subject: [PATCH 48/59] fix panic --- sql/table_iter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/table_iter.go b/sql/table_iter.go index 8df5058274..1c82ac05ad 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -108,7 +108,7 @@ func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { row, err := i.valueRows.NextValueRow(ctx) if err != nil && err == io.EOF { - if err = i.rows.Close(ctx); err != nil { + if err = i.valueRows.Close(ctx); err != nil { return nil, err } i.partition = nil From 6904da1d8daccd2f1543e3e969df9a89a9a2f1fe Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 11:42:30 -0700 Subject: [PATCH 49/59] decimal conversion tests --- sql/types/decimal.go | 10 +- sql/types/decimal_test.go | 455 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 462 insertions(+), 3 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index ebf14a8d3e..a9690f4c42 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -144,11 +144,11 @@ func (t DecimalType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error if hasNulls, res := CompareNullValues(a, b); hasNulls { return res, nil } - aDec, err := ConvertValueToDecimal(ctx, a) + aDec, err := convertValueToDecimal(ctx, a) if err != nil { return 0, err } - bDec, err := ConvertValueToDecimal(ctx, b) + bDec, err := convertValueToDecimal(ctx, b) if err != nil { return 0, err } @@ -412,7 +412,7 @@ func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string { } } -func ConvertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, error) { +func convertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, error) { switch v.Typ { case sqltypes.Int8: x := values.ReadInt8(v.Val) @@ -448,6 +448,10 @@ func ConvertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, erro case sqltypes.Decimal: x := values.ReadDecimal(v.Val) return x, nil + case sqltypes.Bit: + x := values.ReadUint64(v.Val) + bi := new(big.Int).SetUint64(x) + return decimal.NewFromBigInt(bi, 0), nil case sqltypes.Year: x := values.ReadUint16(v.Val) return decimal.NewFromInt(int64(x)), nil diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index e39e3496b6..90b53467f7 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -16,7 +16,10 @@ package types import ( "context" + "encoding/binary" "fmt" + "github.com/dolthub/vitess/go/sqltypes" + "math" "math/big" "reflect" "strings" @@ -426,3 +429,455 @@ func TestDecimalZero(t *testing.T) { }) } } + +func TestConvertValueToDecimal(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp decimal.Decimal + err bool + }{ + // Int8 -> Decimal + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: []byte{127}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(127), + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int16 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(math.MaxInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(math.MinInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(math.MaxInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(math.MinInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(math.MinInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(-1), + }, + + // Uint8 -> Decimal + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(128), + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(255), + }, + + // Uint16 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(math.MaxInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(math.MaxUint16), + }, + + // Uint32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(math.MaxInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(math.MaxUint32), + }, + + // Uint64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromBigInt(new(big.Int).SetUint64(math.MaxUint64), 0), + }, + + // Float32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(123.456), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(-math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(math.MaxFloat32), + }, + + // Float64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(123.456), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(-math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(-math.MaxFloat64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(math.MaxFloat64), + }, + + // Decimal -> Decimal + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromFloat(123.456), + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromInt(math.MinInt64), + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + + // Bit -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromBigInt(new(big.Int).SetUint64(math.MaxUint64), 0), + }, + + // Year -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(1967), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(1901), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(2155), + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := convertValueToDecimal(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.True(t, test.exp.Equal(res), fmt.Sprintf("%v != %v", test.exp, res)) + }) + } +} From 51d3e90105785c5b3aa01abf9a265441c76204ec Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 11:58:26 -0700 Subject: [PATCH 50/59] fix --- sql/types/datetime.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 3dfdc6382c..28407d3d29 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -488,7 +488,7 @@ func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt switch t.baseType { case sqltypes.Date: // TODO: move this to values package - t := values.ReadDatetime(v.Val) + t := values.ReadDate(v.Val) dest = t.AppendFormat(dest, sql.DateLayout) case sqltypes.Datetime, sqltypes.Timestamp: From aa0d6d03549355dbdf4ce172eb1dedacd94dd5b7 Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 30 Oct 2025 19:01:53 +0000 Subject: [PATCH 51/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index 90b53467f7..e12e1c2e53 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/binary" "fmt" - "github.com/dolthub/vitess/go/sqltypes" "math" "math/big" "reflect" @@ -26,11 +25,12 @@ import ( "testing" "time" - "github.com/dolthub/go-mysql-server/sql" - + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" ) func TestDecimalAccuracy(t *testing.T) { From 13556b0e894d8291fdb353440163d4253555eece Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 12:21:16 -0700 Subject: [PATCH 52/59] comparison microbenchmarks --- sql/expression/comparison_test.go | 79 +++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index ec904b7c62..2d788561e8 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package expression_test +package expression import ( + "encoding/binary" + "github.com/dolthub/vitess/go/sqltypes" "testing" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -109,11 +110,11 @@ var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ func TestEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewEquals(get0, get1) + eq := NewEquals(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -136,11 +137,11 @@ func TestEquals(t *testing.T) { func TestNullSafeEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - seq := expression.NewNullSafeEquals(get0, get1) + seq := NewNullSafeEquals(get0, get1) require.NotNil(seq) require.Equal(types.Boolean, seq.Type()) for cmpResult, cases := range cmpCase { @@ -167,11 +168,11 @@ func TestNullSafeEquals(t *testing.T) { func TestLessThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewLessThan(get0, get1) + eq := NewLessThan(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -194,11 +195,11 @@ func TestLessThan(t *testing.T) { func TestGreaterThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewGreaterThan(get0, get1) + eq := NewGreaterThan(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -218,13 +219,49 @@ func TestGreaterThan(t *testing.T) { } } -func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { - t.Helper() - v, err := e.Eval(sql.NewEmptyContext(), row) - require.NoError(t, err) - return v -} - func TestValueComparison(t *testing.T) { // TODO } + +// BenchmarkComparison +// BenchmarkComparison-14 4426766 264.4 ns/op +func BenchmarkComparison(b *testing.B) { + ctx := sql.NewEmptyContext() + gf1 := NewGetField(0, types.Int64, "col1", true) + gf2 := NewGetField(1, types.Int64, "col2", true) + cmp := newComparison(gf1, gf2) + row := sql.Row{1, 1} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := cmp.Compare(ctx, row) + require.NoError(b, err) + require.Equal(b, 0, res) + } +} + +// BenchmarkValueComparison +// BenchmarkValueComparison-14 4115744 285.8 ns/op +func BenchmarkValueComparison(b *testing.B) { + ctx := sql.NewEmptyContext() + gf1 := NewGetField(0, types.Int64, "col1", true) + gf2 := NewGetField(1, types.Int64, "col2", true) + cmp := newComparison(gf1, gf2) + row := sql.ValueRow{ + { + Val: binary.LittleEndian.AppendUint64(nil, uint64(1)), + Typ: sqltypes.Int64, + }, + { + Val: binary.LittleEndian.AppendUint64(nil, uint64(1)), + Typ: sqltypes.Int64, + }, + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := cmp.CompareValue(ctx, row) + require.NoError(b, err) + require.Equal(b, 0, res) + } +} From 31b3a45ef639191d7c4ef195bd92c4ff88b3c9db Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 30 Oct 2025 19:26:23 +0000 Subject: [PATCH 53/59] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index 2d788561e8..9bc73b7ca9 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -16,9 +16,9 @@ package expression import ( "encoding/binary" - "github.com/dolthub/vitess/go/sqltypes" "testing" + "github.com/dolthub/vitess/go/sqltypes" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" From 2ed39f8f8d247644d06f754ffe1670f03db85545 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 12:36:37 -0700 Subject: [PATCH 54/59] aaa --- sql/analyzer/replace_sort.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/analyzer/replace_sort.go b/sql/analyzer/replace_sort.go index 7a2e4b3ff7..534bd48b93 100644 --- a/sql/analyzer/replace_sort.go +++ b/sql/analyzer/replace_sort.go @@ -175,12 +175,12 @@ func replaceIdxSortHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, so sortFields[i] = sortField } else { sameSortFields = false - col2, _ := col.(sql.Expression2) + valCol, _ := col.(sql.ValueExpression) sortFields[i] = sql.SortField{ - Column: col, - Column2: col2, - NullOrdering: sortField.NullOrdering, - Order: sortField.Order, + Column: col, + ValueExprColumn: valCol, + NullOrdering: sortField.NullOrdering, + Order: sortField.Order, } } } From b9da4b429dc0e86537b01b896a02d87b1ec4d18b Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 15:14:56 -0700 Subject: [PATCH 55/59] opt --- go.mod | 2 +- go.sum | 4 ++-- sql/types/datetime.go | 3 --- sql/types/decimal.go | 11 ++++------- sql/types/decimal_test.go | 4 ++-- sql/types/set.go | 2 +- sql/types/time.go | 34 +++++++++++++++++++++++++++++++--- sql/values/encoding.go | 23 ----------------------- 8 files changed, 41 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index 692abdee79..68d8f88fd4 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/lestrrat-go/strftime v1.0.4 github.com/pkg/errors v0.9.1 github.com/pmezard/go-difflib v1.0.0 - github.com/shopspring/decimal v1.3.1 + github.com/shopspring/decimal v1.4.0 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.9.0 go.opentelemetry.io/otel v1.31.0 diff --git a/go.sum b/go.sum index bab2369fe9..ee3d7c83d5 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= -github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 28407d3d29..c10dc759fb 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -487,15 +487,12 @@ func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt } switch t.baseType { case sqltypes.Date: - // TODO: move this to values package t := values.ReadDate(v.Val) dest = t.AppendFormat(dest, sql.DateLayout) - case sqltypes.Datetime, sqltypes.Timestamp: x := values.ReadInt64(v.Val) t := time.UnixMicro(x).UTC() dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout) - default: return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime") } diff --git a/sql/types/decimal.go b/sql/types/decimal.go index a9690f4c42..ea51479d78 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -216,7 +216,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, case int64: return t.ConvertToNullDecimal(decimal.NewFromInt(value)) case uint64: - return t.ConvertToNullDecimal(decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0)) + return t.ConvertToNullDecimal(decimal.NewFromUint64(value)) case float32: return t.ConvertToNullDecimal(decimal.NewFromFloat32(value)) case float64: @@ -351,8 +351,7 @@ func (t DecimalType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqlt return sqltypes.NULL, nil } d := values.ReadDecimal(v.Val) - val := AppendAndSliceString(dest, t.DecimalValueStringFixed(d)) - return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil + return sqltypes.MakeTrusted(sqltypes.Decimal, []byte(t.DecimalValueStringFixed(d))), nil } // String implements Type interface. @@ -437,8 +436,7 @@ func convertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, erro return decimal.NewFromInt(int64(x)), nil case sqltypes.Uint64: x := values.ReadUint64(v.Val) - bi := new(big.Int).SetUint64(x) - return decimal.NewFromBigInt(bi, 0), nil + return decimal.NewFromUint64(x), nil case sqltypes.Float32: x := values.ReadFloat32(v.Val) return decimal.NewFromFloat32(x), nil @@ -450,8 +448,7 @@ func convertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, erro return x, nil case sqltypes.Bit: x := values.ReadUint64(v.Val) - bi := new(big.Int).SetUint64(x) - return decimal.NewFromBigInt(bi, 0), nil + return decimal.NewFromUint64(x), nil case sqltypes.Year: x := values.ReadUint16(v.Val) return decimal.NewFromInt(int64(x)), nil diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index e12e1c2e53..740d0a1dc0 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -701,7 +701,7 @@ func TestConvertValueToDecimal(t *testing.T) { Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), Typ: sqltypes.Uint64, }, - exp: decimal.NewFromBigInt(new(big.Int).SetUint64(math.MaxUint64), 0), + exp: decimal.NewFromUint64(math.MaxUint64), }, // Float32 -> Decimal @@ -835,7 +835,7 @@ func TestConvertValueToDecimal(t *testing.T) { Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), Typ: sqltypes.Bit, }, - exp: decimal.NewFromBigInt(new(big.Int).SetUint64(math.MaxUint64), 0), + exp: decimal.NewFromUint64(math.MaxUint64), }, // Year -> Decimal diff --git a/sql/types/set.go b/sql/types/set.go index ccb642423b..6bba11ac11 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -285,7 +285,7 @@ func (t SetType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes. } // TODO: write append style encoder - res, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value)) // TODO: use unsafe string to byte + res, ok := resultCharset.Encoder().Encode([]byte(value)) if !ok { if len(value) > 50 { value = value[:50] diff --git a/sql/types/time.go b/sql/types/time.go index fb4e153fc1..11c272ccc4 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -279,9 +279,8 @@ func (t TimespanType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sql return sqltypes.NULL, nil } x := values.ReadInt64(v.Val) - // TODO: write version of this that takes advantage of dest - v.Val = Timespan(x).Bytes() - return sqltypes.MakeTrusted(sqltypes.Time, v.Val), nil + dest = Timespan(x).AppendBytes(dest) + return sqltypes.MakeTrusted(sqltypes.Time, dest), nil } // String implements Type interface. @@ -502,6 +501,35 @@ func (t Timespan) Bytes() []byte { return ret[:i] } +func (t Timespan) AppendBytes(dest []byte) []byte { + isNegative, hours, minutes, seconds, microseconds := t.timespanToUnits() + sz := 10 + if microseconds > 0 { + sz += 7 + } + + i := 0 + if isNegative { + dest = append(dest, '-') + i++ + } + + i = appendDigit(int64(hours), 2, dest, i) + dest[i] = ':' + i++ + i = appendDigit(int64(minutes), 2, dest, i) + dest[i] = ':' + i++ + i = appendDigit(int64(seconds), 2, dest, i) + if microseconds > 0 { + dest[i] = '.' + i++ + i = appendDigit(int64(microseconds), 6, dest, i) + } + + return dest[:i] +} + // appendDigit format prints 0-entended integer into buffer func appendDigit(v int64, extend int, buf []byte, i int) int { cmp := int64(1) diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 20d904fe98..2cf4f2a765 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -250,29 +250,6 @@ func WriteUint16(buf []byte, val uint16) []byte { return buf } -func WriteInt24(buf []byte, val int32) []byte { - expectSize(buf, Int24Size) - - var tmp [4]byte - binary.LittleEndian.PutUint32(tmp[:], uint32(val)) - // copy |tmp| to |buf| - buf[2], buf[1], buf[0] = tmp[2], tmp[1], tmp[0] - return buf -} - -func WriteUint24(buf []byte, val uint32) []byte { - expectSize(buf, Uint24Size) - if val > maxUint24 { - panic("uint is greater than max uint24") - } - - var tmp [4]byte - binary.LittleEndian.PutUint32(tmp[:], uint32(val)) - // copy |tmp| to |buf| - buf[2], buf[1], buf[0] = tmp[2], tmp[1], tmp[0] - return buf -} - func WriteInt32(buf []byte, val int32) []byte { expectSize(buf, Int32Size) binary.LittleEndian.PutUint32(buf, uint32(val)) From cbcb957e93cb60864c7572202deec4ec939a79f6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Oct 2025 16:09:49 -0700 Subject: [PATCH 56/59] bytes from string --- sql/types/enum.go | 2 +- sql/types/time.go | 41 +++++++++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/sql/types/enum.go b/sql/types/enum.go index 680ebed0c4..5db9cfde81 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -289,7 +289,7 @@ func (t EnumType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes } // TODO: write append style encoder - res, ok := charset.Encoder().Encode(encodings.StringToBytes(value)) // TODO: use unsafe string to byte + res, ok := charset.Encoder().Encode([]byte(value)) if !ok { if len(value) > 50 { value = value[:50] diff --git a/sql/types/time.go b/sql/types/time.go index 11c272ccc4..29ac916948 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -507,30 +507,39 @@ func (t Timespan) AppendBytes(dest []byte) []byte { if microseconds > 0 { sz += 7 } - - i := 0 if isNegative { dest = append(dest, '-') - i++ } - i = appendDigit(int64(hours), 2, dest, i) - dest[i] = ':' - i++ - i = appendDigit(int64(minutes), 2, dest, i) - dest[i] = ':' - i++ - i = appendDigit(int64(seconds), 2, dest, i) - if microseconds > 0 { - dest[i] = '.' - i++ - i = appendDigit(int64(microseconds), 6, dest, i) + if hours < 10 { + dest = append(dest, '0') } + dest = strconv.AppendInt(dest, int64(hours), 10) + dest = append(dest, ':') - return dest[:i] + if minutes < 10 { + dest = append(dest, '0') + } + dest = strconv.AppendInt(dest, int64(minutes), 10) + dest = append(dest, ':') + + if seconds < 10 { + dest = append(dest, '0') + } + dest = strconv.AppendInt(dest, int64(seconds), 10) + if microseconds > 0 { + dest = append(dest, '.') + cmp := int32(100000) + for cmp > 0 && microseconds < cmp { + dest = append(dest, '0') + cmp /= 10 + } + dest = strconv.AppendInt(dest, int64(microseconds), 10) + } + return dest } -// appendDigit format prints 0-entended integer into buffer +// appendDigit format prints 0-extended integer into buffer func appendDigit(v int64, extend int, buf []byte, i int) int { cmp := int64(1) for _ = range extend - 1 { From d5eb7d90d2d7eefad837ae75871204a7866edc8a Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 31 Oct 2025 12:43:24 -0700 Subject: [PATCH 57/59] comment --- sql/types/bit.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/types/bit.go b/sql/types/bit.go index ce24772911..50a5b71cc3 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -252,9 +252,8 @@ func (t BitType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes } v.Val = v.Val[:numBytes] - // TODO: for whatever reason TestTypesOverWire only works when this is a deep copy? - dest = append(dest, v.Val...) // want the results in big endian + dest = append(dest, v.Val...) for i, j := 0, len(dest)-1; i < j; i, j = i+1, j-1 { dest[i], dest[j] = dest[j], dest[i] } From 21ed9ffe903dd1fc61d557861a789eb2cad478cc Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 31 Oct 2025 13:11:11 -0700 Subject: [PATCH 58/59] feedback --- sql/expression/get_field.go | 2 +- sql/memory.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index ec3d5115ac..205b7b505f 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -131,7 +131,7 @@ func (p *GetField) Type() sql.Type { } // ErrIndexOutOfBounds is returned when the field index is out of the bounds. -var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns") +var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns. \n This is a bug. Please file an issue here: https://github.com/dolthub/dolt/issues") // Eval implements the Expression interface. func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { diff --git a/sql/memory.go b/sql/memory.go index fd0eff631c..1bc49bab1b 100644 --- a/sql/memory.go +++ b/sql/memory.go @@ -64,14 +64,14 @@ type RowsCache interface { Get() []Row } -// Rows2Cache is a cache of Row2s. -type Rows2Cache interface { +// ValueRowsCache is a cache of ValueRows. +type ValueRowsCache interface { RowsCache - // Add2 a new row to the cache. If there is no memory available, it will try to + // AddValueRow a new row to the cache. If there is no memory available, it will try to // free some memory. If after that there is still no memory available, it // will return an error and erase all the content of the cache. AddValueRow(ValueRow) error - // Get2 gets all rows. + // GetValueRow gets all rows. GetValueRow() []ValueRow } @@ -200,7 +200,7 @@ func (m *MemoryManager) NewRowsCache() (RowsCache, DisposeFunc) { // NewRowsCache returns an empty rows cache and a function to dispose it when it's // no longer needed. -func (m *MemoryManager) NewRows2Cache() (Rows2Cache, DisposeFunc) { +func (m *MemoryManager) NewRows2Cache() (ValueRowsCache, DisposeFunc) { c := newRowsCache(m, m.reporter) pos := m.addCache(c) return c, func() { From f6063a114a045c2bebcba0c9b4738cc2eea776ca Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 31 Oct 2025 14:25:05 -0700 Subject: [PATCH 59/59] rest of feedback --- sql/expression/comparison_test.go | 2 +- sql/plan/indexed_table_access.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index 9bc73b7ca9..710333d121 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -220,7 +220,7 @@ func TestGreaterThan(t *testing.T) { } func TestValueComparison(t *testing.T) { - // TODO + t.Skip("TODO: write tests for comparison between sql.Values") } // BenchmarkComparison diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index d728a7a1c4..efb54cad54 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -307,7 +307,7 @@ func (i *IndexedTableAccess) GetLookup(ctx *sql.Context, row sql.Row) (sql.Index return i.lb.GetLookup(ctx, key) } -func (i *IndexedTableAccess) getLookup2(ctx *sql.Context, row sql.ValueRow) (sql.IndexLookup, error) { +func (i *IndexedTableAccess) getValueLookup(ctx *sql.Context, row sql.ValueRow) (sql.IndexLookup, error) { // if the lookup was provided at analysis time (static evaluation), use it. if !i.lookup.IsEmpty() { return i.lookup, nil