Skip to content

Commit f0765b0

Browse files
author
James Cor
committed
a ton of conversions
1 parent d7f4e46 commit f0765b0

File tree

14 files changed

+819
-263
lines changed

14 files changed

+819
-263
lines changed

sql/expression/comparison.go

Lines changed: 121 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package expression
1717
import (
1818
"fmt"
1919

20-
"github.com/dolthub/vitess/go/sqltypes"
2120
"gopkg.in/src-d/go-errors.v1"
2221

2322
"github.com/dolthub/go-mysql-server/sql"
@@ -158,6 +157,59 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
158157
return compareType.Compare(ctx, l, r)
159158
}
160159

160+
// CompareValue the two given values using the types of the expressions in the comparison.
161+
func (c *comparison) CompareValue(ctx *sql.Context, row sql.ValueRow) (int, error) {
162+
// TODO: avoid type assertions
163+
lv, err := c.LeftChild.(sql.ValueExpression).EvalValue(ctx, row)
164+
if err != nil {
165+
return 0, err
166+
}
167+
rv, err := c.RightChild.(sql.ValueExpression).EvalValue(ctx, row)
168+
if err != nil {
169+
return 0, err
170+
}
171+
172+
if lv.IsNull() || rv.IsNull() {
173+
return 0, nil
174+
}
175+
176+
lTyp, rTyp := c.LeftChild.Type().(sql.ValueType), c.RightChild.Type().(sql.ValueType)
177+
if types.TypesEqual(lTyp, rTyp) {
178+
return lTyp.(sql.ValueType).CompareValue(ctx, lv, rv)
179+
}
180+
181+
// TODO: enums
182+
183+
// TODO: sets
184+
185+
if types.IsNumber(lTyp) || types.IsNumber(rTyp) {
186+
if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
187+
return types.Uint64.(sql.ValueType).CompareValue(ctx, lv, rv)
188+
}
189+
if types.IsSigned(lTyp) && types.IsSigned(rTyp) {
190+
return types.Int64.(sql.ValueType).CompareValue(ctx, lv, rv)
191+
}
192+
if types.IsDecimal(lTyp) || types.IsDecimal(rTyp) {
193+
return types.InternalDecimalType.(sql.ValueType).CompareValue(ctx, lv, rv)
194+
}
195+
return types.Float64.(sql.ValueType).CompareValue(ctx, lv, rv)
196+
}
197+
return lTyp.CompareValue(ctx, lv, rv)
198+
}
199+
200+
// IsValueExpression returns whether every child supports sql.ValueExpression
201+
func (c *comparison) IsValueExpression() bool {
202+
l, ok := c.LeftChild.(sql.ValueExpression)
203+
if !ok {
204+
return false
205+
}
206+
r, ok := c.RightChild.(sql.ValueExpression)
207+
if !ok {
208+
return false
209+
}
210+
return l.IsValueExpression() && r.IsValueExpression()
211+
}
212+
161213
func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
162214
left, err := c.Left().Eval(ctx, row)
163215
if err != nil {
@@ -520,71 +572,6 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
520572
return result == 1, nil
521573
}
522574

523-
// EvalValue implements the sql.ValueExpression interface.
524-
func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) {
525-
lv, err := gt.comparison.LeftChild.(sql.ValueExpression).EvalValue(ctx, row)
526-
if err != nil {
527-
return sql.Value{}, err
528-
}
529-
rv, err := gt.comparison.RightChild.(sql.ValueExpression).EvalValue(ctx, row)
530-
if err != nil {
531-
return sql.Value{}, err
532-
}
533-
534-
// TODO: move this logic into comparison
535-
var cmp byte
536-
if sqltypes.IsUnsigned(lv.Typ) && sqltypes.IsUnsigned(rv.Typ) {
537-
l, cErr := types.ConvertValueToUint64(lv)
538-
if cErr != nil {
539-
return sql.Value{}, cErr
540-
}
541-
r, cErr := types.ConvertValueToUint64(rv)
542-
if cErr != nil {
543-
return sql.Value{}, cErr
544-
}
545-
if l > r {
546-
cmp = 1
547-
}
548-
} else {
549-
l, cErr := types.ConvertValueToInt64(lv)
550-
if cErr != nil {
551-
return sql.Value{}, cErr
552-
}
553-
r, cErr := types.ConvertValueToInt64(rv)
554-
if cErr != nil {
555-
return sql.Value{}, cErr
556-
}
557-
if l > r {
558-
cmp = 1
559-
}
560-
}
561-
562-
res := sql.Value{
563-
Val: []byte{cmp},
564-
Typ: sqltypes.Int8,
565-
}
566-
return res, nil
567-
}
568-
569-
// IsValueRowIter implements the ValueExpression interface.
570-
func (gt *GreaterThan) IsValueExpression() bool {
571-
l, ok := gt.comparison.LeftChild.(sql.ValueExpression)
572-
if !ok {
573-
return false
574-
}
575-
if !l.IsValueExpression() {
576-
return false
577-
}
578-
r, ok := gt.comparison.RightChild.(sql.ValueExpression)
579-
if !ok {
580-
return false
581-
}
582-
if !r.IsValueExpression() {
583-
return false
584-
}
585-
return true
586-
}
587-
588575
// WithChildren implements the Expression interface.
589576
func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) {
590577
if len(children) != 2 {
@@ -605,6 +592,23 @@ func (gt *GreaterThan) DebugString() string {
605592
return pr.String()
606593
}
607594

595+
// EvalValue implements the sql.ValueExpression interface.
596+
func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) {
597+
cmp, err := gt.CompareValue(ctx, row)
598+
if err != nil {
599+
return sql.NullValue, err
600+
}
601+
if cmp != 1 {
602+
return sql.FalseValue, nil
603+
}
604+
return sql.TrueValue, nil
605+
}
606+
607+
// IsValueExpression implements the ValueExpression interface.
608+
func (gt *GreaterThan) IsValueExpression() bool {
609+
return gt.comparison.IsValueExpression()
610+
}
611+
608612
// LessThan is a comparison that checks an expression is less than another.
609613
type LessThan struct {
610614
comparison
@@ -630,10 +634,8 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
630634
if ErrNilOperand.Is(err) {
631635
return nil, nil
632636
}
633-
634637
return nil, err
635638
}
636-
637639
return result == -1, nil
638640
}
639641

@@ -657,6 +659,23 @@ func (lt *LessThan) DebugString() string {
657659
return pr.String()
658660
}
659661

662+
// EvalValue implements the sql.ValueExpression interface.
663+
func (lt *LessThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) {
664+
cmp, err := lt.CompareValue(ctx, row)
665+
if err != nil {
666+
return sql.NullValue, err
667+
}
668+
if cmp != -1 {
669+
return sql.FalseValue, nil
670+
}
671+
return sql.TrueValue, nil
672+
}
673+
674+
// IsValueExpression implements the ValueExpression interface.
675+
func (lt *LessThan) IsValueExpression() bool {
676+
return lt.comparison.IsValueExpression()
677+
}
678+
660679
// GreaterThanOrEqual is a comparison that checks an expression is greater or equal to
661680
// another.
662681
type GreaterThanOrEqual struct {
@@ -683,10 +702,8 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{},
683702
if ErrNilOperand.Is(err) {
684703
return nil, nil
685704
}
686-
687705
return nil, err
688706
}
689-
690707
return result > -1, nil
691708
}
692709

@@ -710,6 +727,23 @@ func (gte *GreaterThanOrEqual) DebugString() string {
710727
return pr.String()
711728
}
712729

730+
// EvalValue implements the sql.ValueExpression interface.
731+
func (gte *GreaterThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) {
732+
cmp, err := gte.CompareValue(ctx, row)
733+
if err != nil {
734+
return sql.NullValue, err
735+
}
736+
if cmp == -1 {
737+
return sql.FalseValue, nil
738+
}
739+
return sql.TrueValue, nil
740+
}
741+
742+
// IsValueExpression implements the ValueExpression interface.
743+
func (gte *GreaterThanOrEqual) IsValueExpression() bool {
744+
return gte.comparison.IsValueExpression()
745+
}
746+
713747
// LessThanOrEqual is a comparison that checks an expression is equal or lower than
714748
// another.
715749
type LessThanOrEqual struct {
@@ -763,6 +797,23 @@ func (lte *LessThanOrEqual) DebugString() string {
763797
return pr.String()
764798
}
765799

800+
// EvalValue implements the sql.ValueExpression interface.
801+
func (lte *LessThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) {
802+
cmp, err := lte.CompareValue(ctx, row)
803+
if err != nil {
804+
return sql.NullValue, err
805+
}
806+
if cmp == 1 {
807+
return sql.FalseValue, nil
808+
}
809+
return sql.TrueValue, nil
810+
}
811+
812+
// IsValueExpression implements the ValueExpression interface.
813+
func (lte *LessThanOrEqual) IsValueExpression() bool {
814+
return lte.comparison.IsValueExpression()
815+
}
816+
766817
var (
767818
// ErrUnsupportedInOperand is returned when there is an invalid righthand
768819
// operand in an IN operator.

sql/type.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ type Type interface {
108108
// ValueType is an extension of the Type interface, that operates over sql.Values.
109109
type ValueType interface {
110110
Type
111+
// CompareValue returns an integer comparing two sql.Values.
112+
// The result will be 0 if a == b, -1 if a < b, and +1 if a > b.
113+
CompareValue(*Context, Value, Value) (int, error)
111114
// SQLValue returns the sqltypes.Value for the given sql.Value.
112115
// Implementations can optionally use |dest| to append
113116
// serialized data, but should not mutate existing data.

sql/types/bit.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,31 @@ func (t BitType_) Compare(ctx context.Context, a interface{}, b interface{}) (in
103103
return 0, nil
104104
}
105105

106+
// CompareValue implements the ValueType interface
107+
func (t BitType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) {
108+
if hasNulls, res := CompareNullValues(a, b); hasNulls {
109+
return res, nil
110+
}
111+
112+
av, err := ConvertValueToUint64(ctx, a)
113+
if err != nil {
114+
return 0, err
115+
}
116+
bv, err := ConvertValueToUint64(ctx, b)
117+
if err != nil {
118+
return 0, err
119+
}
120+
121+
switch {
122+
case av < bv:
123+
return -1, nil
124+
case av > bv:
125+
return 1, nil
126+
default:
127+
return 0, nil
128+
}
129+
}
130+
106131
// Convert implements Type interface.
107132
func (t BitType_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) {
108133
if v == nil {
@@ -211,7 +236,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
211236
return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
212237
}
213238

214-
// ToSQLValue implements ValueType interface.
239+
// SQLValue implements ValueType interface.
215240
func (t BitType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
216241
if v.IsNull() {
217242
return sqltypes.NULL, nil

sql/types/conversion.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,24 @@ func CompareNulls(a interface{}, b interface{}) (bool, int) {
472472
return false, 0
473473
}
474474

475+
// CompareNullValues compares two sql.Values, and returns true if either is null.
476+
// The returned integer represents the ordering, with a rule that states nulls
477+
// as being ordered before non-nulls.
478+
func CompareNullValues(a, b sql.Value) (bool, int) {
479+
aIsNull := a.IsNull()
480+
bIsNull := b.IsNull()
481+
switch {
482+
case aIsNull && bIsNull:
483+
return true, 0
484+
case aIsNull && !bIsNull:
485+
return false, 1
486+
case !aIsNull && bIsNull:
487+
return false, -1
488+
default:
489+
return false, 0
490+
}
491+
}
492+
475493
// NumColumns returns the number of columns in a type. This is one for all
476494
// types, except tuples.
477495
func NumColumns(t sql.Type) int {

0 commit comments

Comments
 (0)