diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 3e2fcc1631..2567274cf5 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -9355,7 +9355,7 @@ where }, }, { - Name: "enum conversion to strings", + Name: "enum conversions", Dialect: "mysql", SetUpScript: []string{ "create table t (e enum('abc', 'defg', 'hijkl'));", @@ -9491,6 +9491,38 @@ where {"abc"}, }, }, + { + Query: "select e, cast(e as unsigned) from t order by e;", + Expected: []sql.Row{ + {"abc", uint64(1)}, + {"defg", uint64(2)}, + {"hijkl", uint64(3)}, + }, + }, + { + Query: "select e, cast(e as decimal) from t order by e;", + Expected: []sql.Row{ + {"abc", "1"}, + {"defg", "2"}, + {"hijkl", "3"}, + }, + }, + { + Query: "select e, cast(e as float) from t order by e;", + Expected: []sql.Row{ + {"abc", float32(1)}, + {"defg", float32(2)}, + {"hijkl", float32(3)}, + }, + }, + { + Query: "select e, cast(e as double) from t order by e;", + Expected: []sql.Row{ + {"abc", float64(1)}, + {"defg", float64(2)}, + {"hijkl", float64(3)}, + }, + }, }, }, { @@ -9953,7 +9985,7 @@ where }, }, { - Name: "set conversion to strings", + Name: "set conversions", Dialect: "mysql", SetUpScript: []string{ "create table t (s set('abc', 'defg', 'hijkl'));", @@ -10077,25 +10109,59 @@ where }, }, { - // https://github.com/dolthub/dolt/issues/9511 - Skip: true, Query: "select s, cast(s as char) from t order by s;", Expected: []sql.Row{ {"abc", "abc"}, + {"defg", "defg"}, {"abc,defg", "abc,defg"}, {"abc,defg,hijkl", "abc,defg,hijkl"}, }, }, { - // https://github.com/dolthub/dolt/issues/9511 - Skip: true, Query: "select s, cast(s as binary) from t order by s;", Expected: []sql.Row{ {"abc", []uint8("abc")}, + {"defg", []uint8("defg")}, {"abc,defg", []uint8("abc,defg")}, {"abc,defg,hijkl", []uint8("abc,defg,hijkl")}, }, }, + { + Query: "select s, cast(s as unsigned) from t order by s;", + Expected: []sql.Row{ + {"abc", uint64(1)}, + {"defg", uint64(2)}, + {"abc,defg", uint64(3)}, + {"abc,defg,hijkl", uint64(7)}, + }, + }, + { + Query: "select s, cast(s as decimal) from t order by s;", + Expected: []sql.Row{ + {"abc", "1"}, + {"defg", "2"}, + {"abc,defg", "3"}, + {"abc,defg,hijkl", "7"}, + }, + }, + { + Query: "select s, cast(s as float) from t order by s;", + Expected: []sql.Row{ + {"abc", float32(1)}, + {"defg", float32(2)}, + {"abc,defg", float32(3)}, + {"abc,defg,hijkl", float32(7)}, + }, + }, + { + Query: "select s, cast(s as double) from t order by s;", + Expected: []sql.Row{ + {"abc", float64(1)}, + {"defg", float64(2)}, + {"abc,defg", float64(3)}, + {"abc,defg,hijkl", float64(7)}, + }, + }, }, }, { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index d115de49ca..d60df93c1e 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -278,17 +278,22 @@ func (c *Convert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return casted, nil } -// convertValue only returns an error if converting to JSON, Date, and Datetime; -// the zero value is returned for float types. Nil is returned in all other cases. +// convertValue converts a value from its current type to the specified target type for CAST/CONVERT operations. +// It handles type-specific conversion logic and applies length/scale constraints where applicable. // If |typeLength| and |typeScale| are 0, they are ignored, otherwise they are used as constraints on the // converted type where applicable (e.g. Char conversion supports only |typeLength|, Decimal conversion supports // |typeLength| and |typeScale|). +// Only returns an error if converting to JSON, Date, and Datetime; the zero value is returned for float types. +// Nil is returned in all other cases. func convertValue(ctx *sql.Context, val interface{}, castTo string, originType sql.Type, typeLength, typeScale int) (interface{}, error) { if val == nil { return nil, nil } switch strings.ToLower(castTo) { case ConvertToBinary: + if types.IsSet(originType) || types.IsEnum(originType) { + val, _ = types.TypeAwareConversion(ctx, val, originType, types.LongText) + } b, _, err := types.LongBlob.Convert(ctx, val) if err != nil { return nil, nil @@ -307,6 +312,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return truncateConvertedValue(b, typeLength) case ConvertToChar, ConvertToNChar: + val, _ = types.TypeAwareConversion(ctx, val, originType, types.LongText) s, _, err := types.LongText.Convert(ctx, val) if err != nil { return nil, nil