@@ -17,8 +17,6 @@ package expression
1717import (
1818 "fmt"
1919
20- "github.com/dolthub/vitess/go/mysql"
21-
2220 "github.com/dolthub/go-mysql-server/sql"
2321 "github.com/dolthub/go-mysql-server/sql/hash"
2422 "github.com/dolthub/go-mysql-server/sql/types"
@@ -130,58 +128,59 @@ func validateAndEvalRightTuple(ctx *sql.Context, lType sql.Type, right Tuple, ro
130128
131129// Eval implements the Expression interface.
132130func (in * InTuple ) Eval (ctx * sql.Context , row sql.Row ) (interface {}, error ) {
133- leftVal , err := in .Left ().Eval (ctx , row )
131+ lVal , err := in .Left ().Eval (ctx , row )
134132 if err != nil {
135133 return nil , err
136134 }
137- if leftVal == nil {
135+ if lVal == nil {
138136 return nil , nil
139137 }
140138
139+ lType := in .Left ().Type ()
140+ lColCount := types .NumColumns (lType )
141+ lLit := NewLiteral (lVal , lType )
142+
141143 right , isTuple := in .Right ().(Tuple )
142144 if ! isTuple {
143145 return nil , ErrUnsupportedInOperand .New (right )
144146 }
145147
146- lType := in .Left ().Type ()
147- rVals , cmpType , rHasNull , err := validateAndEvalRightTuple (ctx , lType , right , row )
148- if err != nil {
149- return nil , err
150- }
148+ var rHasNull bool
149+ for _ , el := range right {
150+ rType := el .Type ()
151+ if rType == types .Null {
152+ rHasNull = true
153+ continue
154+ }
151155
152- lv , _ , lErr := cmpType .Convert (ctx , leftVal )
153- if lErr != nil {
154- if sql .ErrTruncatedIncorrect .Is (lErr ) {
155- ctx .Warn (mysql .ERTruncatedWrongValue , "%s" , lErr .Error ())
156- } else {
157- lv = cmpType .Zero ()
156+ // Nested tuples must have the same number of columns
157+ rColCount := types .NumColumns (rType )
158+ if rColCount != lColCount {
159+ return nil , sql .ErrInvalidOperandColumns .New (lColCount , rColCount )
158160 }
159- }
160161
161- for _ , rVal := range rVals {
162+ rVal , rErr := el .Eval (ctx , row )
163+ if rErr != nil {
164+ return nil , rErr
165+ }
162166 if rVal == nil {
167+ rHasNull = true
163168 continue
164169 }
165- rv , _ , rErr := cmpType .Convert (ctx , rVal )
166- if rErr != nil {
167- if sql .ErrTruncatedIncorrect .Is (rErr ) {
168- ctx .Warn (mysql .ERTruncatedWrongValue , "%s" , rErr .Error ())
169- } else {
170- rv = cmpType .Zero ()
171- }
172- }
173- cmp , cErr := cmpType .Compare (ctx , lv , rv )
170+
171+ cmpExpr := newComparison (lLit , NewLiteral (rVal , rType ))
172+ res , cErr := cmpExpr .Compare (ctx , nil )
174173 if cErr != nil {
175- continue
174+ return nil , cErr
176175 }
177- if cmp == 0 {
176+ if res == 0 {
178177 return true , nil
179178 }
180179 }
180+
181181 if rHasNull {
182182 return nil , nil
183183 }
184-
185184 return false , nil
186185}
187186
@@ -258,16 +257,45 @@ func newInMap(ctx *sql.Context, lType sql.Type, right Tuple) (map[uint64]struct{
258257 if lType == types .Null {
259258 return nil , nil , true , nil
260259 }
260+ lColCount := types .NumColumns (lType )
261261 if len (right ) == 0 {
262262 return nil , nil , false , nil
263263 }
264- rVals , cmpType , rHasNull , err := validateAndEvalRightTuple (ctx , lType , right , nil )
265- if err != nil {
266- return nil , nil , false , err
264+ // only non-nil elements are included
265+ rVals := make ([]any , 0 , len (right ))
266+ var rHasNull bool
267+ for _ , el := range right {
268+ rType := el .Type ()
269+ rColCount := types .NumColumns (rType )
270+ if lColCount != rColCount {
271+ return nil , nil , false , sql .ErrInvalidOperandColumns .New (lColCount , rColCount )
272+ }
273+ if rType == types .Null {
274+ rHasNull = true
275+ continue
276+ }
277+ rVal , err := el .Eval (ctx , nil )
278+ if err != nil {
279+ return nil , nil , false , err
280+ }
281+ if rVal == nil {
282+ rHasNull = true
283+ continue
284+ }
285+ rVals = append (rVals , rVal )
267286 }
268- elements := make (map [uint64 ]struct {})
269- for _ , v := range rVals {
270- key , hErr := hash .HashOfSimple (ctx , v , cmpType )
287+
288+ var cmpType sql.Type
289+ if types .IsEnum (lType ) || types .IsSet (lType ) {
290+ cmpType = lType
291+ } else {
292+ // If we've made it this far, we are guaranteed that the right Tuple has a consistent set of types
293+ // (all numeric, string, or time), so it is enough to just compare against the first element of the right Tuple
294+ cmpType = types .GetCompareType (lType , right [0 ].Type ())
295+ }
296+ elements := map [uint64 ]struct {}{}
297+ for _ , rVal := range rVals {
298+ key , hErr := hash .HashOfSimple (ctx , rVal , cmpType )
271299 if hErr != nil {
272300 return nil , nil , false , hErr
273301 }
0 commit comments