@@ -17,7 +17,6 @@ package expression
1717import (
1818 "fmt"
1919
20- "github.com/dolthub/vitess/go/sqltypes"
2120 "gopkg.in/src-d/go-errors.v1"
2221
2322 "github.com/dolthub/go-mysql-server/sql"
@@ -158,6 +157,59 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
158157 return compareType .Compare (ctx , l , r )
159158}
160159
160+ // CompareValue the two given values using the types of the expressions in the comparison.
161+ func (c * comparison ) CompareValue (ctx * sql.Context , row sql.ValueRow ) (int , error ) {
162+ // TODO: avoid type assertions
163+ lv , err := c .LeftChild .(sql.ValueExpression ).EvalValue (ctx , row )
164+ if err != nil {
165+ return 0 , err
166+ }
167+ rv , err := c .RightChild .(sql.ValueExpression ).EvalValue (ctx , row )
168+ if err != nil {
169+ return 0 , err
170+ }
171+
172+ if lv .IsNull () || rv .IsNull () {
173+ return 0 , nil
174+ }
175+
176+ lTyp , rTyp := c .LeftChild .Type ().(sql.ValueType ), c .RightChild .Type ().(sql.ValueType )
177+ if types .TypesEqual (lTyp , rTyp ) {
178+ return lTyp .(sql.ValueType ).CompareValue (ctx , lv , rv )
179+ }
180+
181+ // TODO: enums
182+
183+ // TODO: sets
184+
185+ if types .IsNumber (lTyp ) || types .IsNumber (rTyp ) {
186+ if types .IsUnsigned (lTyp ) && types .IsUnsigned (rTyp ) {
187+ return types .Uint64 .(sql.ValueType ).CompareValue (ctx , lv , rv )
188+ }
189+ if types .IsSigned (lTyp ) && types .IsSigned (rTyp ) {
190+ return types .Int64 .(sql.ValueType ).CompareValue (ctx , lv , rv )
191+ }
192+ if types .IsDecimal (lTyp ) || types .IsDecimal (rTyp ) {
193+ return types .InternalDecimalType .(sql.ValueType ).CompareValue (ctx , lv , rv )
194+ }
195+ return types .Float64 .(sql.ValueType ).CompareValue (ctx , lv , rv )
196+ }
197+ return lTyp .CompareValue (ctx , lv , rv )
198+ }
199+
200+ // IsValueExpression returns whether every child supports sql.ValueExpression
201+ func (c * comparison ) IsValueExpression () bool {
202+ l , ok := c .LeftChild .(sql.ValueExpression )
203+ if ! ok {
204+ return false
205+ }
206+ r , ok := c .RightChild .(sql.ValueExpression )
207+ if ! ok {
208+ return false
209+ }
210+ return l .IsValueExpression () && r .IsValueExpression ()
211+ }
212+
161213func (c * comparison ) evalLeftAndRight (ctx * sql.Context , row sql.Row ) (interface {}, interface {}, error ) {
162214 left , err := c .Left ().Eval (ctx , row )
163215 if err != nil {
@@ -520,71 +572,6 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
520572 return result == 1 , nil
521573}
522574
523- // EvalValue implements the sql.ValueExpression interface.
524- func (gt * GreaterThan ) EvalValue (ctx * sql.Context , row sql.ValueRow ) (sql.Value , error ) {
525- lv , err := gt .comparison .LeftChild .(sql.ValueExpression ).EvalValue (ctx , row )
526- if err != nil {
527- return sql.Value {}, err
528- }
529- rv , err := gt .comparison .RightChild .(sql.ValueExpression ).EvalValue (ctx , row )
530- if err != nil {
531- return sql.Value {}, err
532- }
533-
534- // TODO: move this logic into comparison
535- var cmp byte
536- if sqltypes .IsUnsigned (lv .Typ ) && sqltypes .IsUnsigned (rv .Typ ) {
537- l , cErr := types .ConvertValueToUint64 (lv )
538- if cErr != nil {
539- return sql.Value {}, cErr
540- }
541- r , cErr := types .ConvertValueToUint64 (rv )
542- if cErr != nil {
543- return sql.Value {}, cErr
544- }
545- if l > r {
546- cmp = 1
547- }
548- } else {
549- l , cErr := types .ConvertValueToInt64 (lv )
550- if cErr != nil {
551- return sql.Value {}, cErr
552- }
553- r , cErr := types .ConvertValueToInt64 (rv )
554- if cErr != nil {
555- return sql.Value {}, cErr
556- }
557- if l > r {
558- cmp = 1
559- }
560- }
561-
562- res := sql.Value {
563- Val : []byte {cmp },
564- Typ : sqltypes .Int8 ,
565- }
566- return res , nil
567- }
568-
569- // IsValueRowIter implements the ValueExpression interface.
570- func (gt * GreaterThan ) IsValueExpression () bool {
571- l , ok := gt .comparison .LeftChild .(sql.ValueExpression )
572- if ! ok {
573- return false
574- }
575- if ! l .IsValueExpression () {
576- return false
577- }
578- r , ok := gt .comparison .RightChild .(sql.ValueExpression )
579- if ! ok {
580- return false
581- }
582- if ! r .IsValueExpression () {
583- return false
584- }
585- return true
586- }
587-
588575// WithChildren implements the Expression interface.
589576func (gt * GreaterThan ) WithChildren (children ... sql.Expression ) (sql.Expression , error ) {
590577 if len (children ) != 2 {
@@ -605,6 +592,23 @@ func (gt *GreaterThan) DebugString() string {
605592 return pr .String ()
606593}
607594
595+ // EvalValue implements the sql.ValueExpression interface.
596+ func (gt * GreaterThan ) EvalValue (ctx * sql.Context , row sql.ValueRow ) (sql.Value , error ) {
597+ cmp , err := gt .CompareValue (ctx , row )
598+ if err != nil {
599+ return sql .NullValue , err
600+ }
601+ if cmp != 1 {
602+ return sql .FalseValue , nil
603+ }
604+ return sql .TrueValue , nil
605+ }
606+
607+ // IsValueExpression implements the ValueExpression interface.
608+ func (gt * GreaterThan ) IsValueExpression () bool {
609+ return gt .comparison .IsValueExpression ()
610+ }
611+
608612// LessThan is a comparison that checks an expression is less than another.
609613type LessThan struct {
610614 comparison
@@ -630,10 +634,8 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
630634 if ErrNilOperand .Is (err ) {
631635 return nil , nil
632636 }
633-
634637 return nil , err
635638 }
636-
637639 return result == - 1 , nil
638640}
639641
@@ -657,6 +659,23 @@ func (lt *LessThan) DebugString() string {
657659 return pr .String ()
658660}
659661
662+ // EvalValue implements the sql.ValueExpression interface.
663+ func (lt * LessThan ) EvalValue (ctx * sql.Context , row sql.ValueRow ) (sql.Value , error ) {
664+ cmp , err := lt .CompareValue (ctx , row )
665+ if err != nil {
666+ return sql .NullValue , err
667+ }
668+ if cmp != - 1 {
669+ return sql .FalseValue , nil
670+ }
671+ return sql .TrueValue , nil
672+ }
673+
674+ // IsValueExpression implements the ValueExpression interface.
675+ func (lt * LessThan ) IsValueExpression () bool {
676+ return lt .comparison .IsValueExpression ()
677+ }
678+
660679// GreaterThanOrEqual is a comparison that checks an expression is greater or equal to
661680// another.
662681type GreaterThanOrEqual struct {
@@ -683,10 +702,8 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{},
683702 if ErrNilOperand .Is (err ) {
684703 return nil , nil
685704 }
686-
687705 return nil , err
688706 }
689-
690707 return result > - 1 , nil
691708}
692709
@@ -710,6 +727,23 @@ func (gte *GreaterThanOrEqual) DebugString() string {
710727 return pr .String ()
711728}
712729
730+ // EvalValue implements the sql.ValueExpression interface.
731+ func (gte * GreaterThanOrEqual ) EvalValue (ctx * sql.Context , row sql.ValueRow ) (sql.Value , error ) {
732+ cmp , err := gte .CompareValue (ctx , row )
733+ if err != nil {
734+ return sql .NullValue , err
735+ }
736+ if cmp == - 1 {
737+ return sql .FalseValue , nil
738+ }
739+ return sql .TrueValue , nil
740+ }
741+
742+ // IsValueExpression implements the ValueExpression interface.
743+ func (gte * GreaterThanOrEqual ) IsValueExpression () bool {
744+ return gte .comparison .IsValueExpression ()
745+ }
746+
713747// LessThanOrEqual is a comparison that checks an expression is equal or lower than
714748// another.
715749type LessThanOrEqual struct {
@@ -763,6 +797,23 @@ func (lte *LessThanOrEqual) DebugString() string {
763797 return pr .String ()
764798}
765799
800+ // EvalValue implements the sql.ValueExpression interface.
801+ func (lte * LessThanOrEqual ) EvalValue (ctx * sql.Context , row sql.ValueRow ) (sql.Value , error ) {
802+ cmp , err := lte .CompareValue (ctx , row )
803+ if err != nil {
804+ return sql .NullValue , err
805+ }
806+ if cmp == 1 {
807+ return sql .FalseValue , nil
808+ }
809+ return sql .TrueValue , nil
810+ }
811+
812+ // IsValueExpression implements the ValueExpression interface.
813+ func (lte * LessThanOrEqual ) IsValueExpression () bool {
814+ return lte .comparison .IsValueExpression ()
815+ }
816+
766817var (
767818 // ErrUnsupportedInOperand is returned when there is an invalid righthand
768819 // operand in an IN operator.
0 commit comments