From c857474ff7d25d48b27baff67e521f2a81bd958d Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 14 Oct 2024 16:57:39 -0700 Subject: [PATCH 1/2] compare and convert system types properly --- enginetest/queries/script_queries.go | 16 ++++ sql/analyzer/resolve_create_select.go | 4 + sql/expression/function/coalesce.go | 106 +++++++++++++---------- sql/expression/function/coalesce_test.go | 89 ++++++++++++++++--- sql/types/typecheck.go | 47 ++++++---- sql/types/typecheck_test.go | 60 +++++++++++++ 6 files changed, 249 insertions(+), 73 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index b6dc976315..4d35189e59 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7499,6 +7499,22 @@ where }, }, }, + { + Name: "coalesce with system types", + SetUpScript: []string{ + "create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "describe t;", + Expected: []sql.Row{ + {"port1", "bigint", "NO", "", nil, ""}, + {"port2", "bigint", "NO", "", nil, ""}, + {"port3", "bigint", "NO", "", nil, ""}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/analyzer/resolve_create_select.go b/sql/analyzer/resolve_create_select.go index f0bbcec056..34960ca178 100644 --- a/sql/analyzer/resolve_create_select.go +++ b/sql/analyzer/resolve_create_select.go @@ -31,6 +31,10 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan. for i, col := range mergedSchema { tempCol := *col tempCol.Source = ct.Name() + // replace system variable types with their underlying types + if sysType, isSysTyp := tempCol.Type.(sql.SystemVariableType); isSysTyp { + tempCol.Type = sysType.UnderlyingType() + } newSch[i] = &tempCol } diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 99734f4263..9d60e11a02 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -16,11 +16,11 @@ package function import ( "fmt" - "strings" + "github.com/dolthub/go-mysql-server/sql/expression" +"strings" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/types" ) // Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. @@ -53,61 +53,73 @@ func (c *Coalesce) Description() string { // Type implements the sql.Expression interface. // The return type of Type() is the aggregated type of the argument types. func (c *Coalesce) Type() sql.Type { - typ := types.Null - for _, arg := range c.args { + retType := types.Null + for i, arg := range c.args { if arg == nil { continue } - t := arg.Type() + argType := arg.Type() + if sysVarType, ok := argType.(sql.SystemVariableType); ok { + argType = sysVarType.UnderlyingType() + } + if i == 0 { + retType = argType + continue + } + if argType == nil || argType == types.Null { + continue + } + if retType.Equals(argType) { + continue + } + // special case for signed and unsigned integers - if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) { - typ = types.MustCreateDecimalType(20, 0) + if (types.IsSigned(retType) && types.IsUnsigned(argType)) || (types.IsUnsigned(retType) && types.IsSigned(argType)) { + retType = types.MustCreateDecimalType(20, 0) continue } - if t != nil && t != types.Null { - convType := expression.GetConvertToType(typ, t) - switch convType { - case expression.ConvertToChar: - // special case for float64s - if (t == types.Float64 || typ == types.Float64) && !types.IsText(t) && !types.IsText(typ) { - typ = types.Float64 - continue - } - // Can't get any larger than this - return types.LongText - case expression.ConvertToDecimal: - if typ == types.Float64 || t == types.Float64 { - typ = types.Float64 - } else if types.IsDecimal(t) { - typ = t - } else if !types.IsDecimal(typ) { - typ = types.MustCreateDecimalType(10, 0) - } - case expression.ConvertToUnsigned: - if typ == types.Uint64 || t == types.Uint64 { - typ = types.Uint64 - } else { - typ = types.Uint32 - } - case expression.ConvertToSigned: - if typ == types.Int64 || t == types.Int64 { - typ = types.Int64 - } else { - typ = types.Int32 - } - case expression.ConvertToFloat: - if typ == types.Float64 || t == types.Float64 { - typ = types.Float64 - } else { - typ = types.Float32 - } - default: + convType := expression.GetConvertToType(retType, argType) + switch convType { + case expression.ConvertToChar: + // special case for float64s + if (argType == types.Float64 || retType == types.Float64) && !types.IsText(argType) && !types.IsText(retType) { + retType = types.Float64 + continue + } + // Can't get any larger than this + return types.LongText + case expression.ConvertToDecimal: + if retType == types.Float64 || argType == types.Float64 { + retType = types.Float64 + } else if types.IsDecimal(argType) { + retType = argType + } else if !types.IsDecimal(retType) { + retType = types.MustCreateDecimalType(10, 0) + } + case expression.ConvertToUnsigned: + if retType == types.Uint64 || argType == types.Uint64 { + retType = types.Uint64 + } else { + retType = types.Uint32 + } + case expression.ConvertToSigned: + if retType == types.Int64 || argType == types.Int64 { + retType = types.Int64 + } else { + retType = types.Int32 + } + case expression.ConvertToFloat: + if retType == types.Float64 || argType == types.Float64 { + retType = types.Float64 + } else { + retType = types.Float32 } + default: } } - return typ + return retType } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go index 76ac8105aa..bccdeedb80 100644 --- a/sql/expression/function/coalesce_test.go +++ b/sql/expression/function/coalesce_test.go @@ -15,10 +15,10 @@ package function import ( - "testing" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/require" +"testing" + + "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" @@ -152,17 +152,86 @@ func TestCoalesce(t *testing.T) { typ: types.Float64, nullable: false, }, + { + name: "coalesce(sysInt, sysInt)", + input: []sql.Expression{ + expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)), + expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)), + }, + expected: 1, + typ: types.Int64, + nullable: false, + }, + { + name: "coalesce(sysInt, sysUint)", + input: []sql.Expression{ + expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)), + expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)), + }, + expected: 1, + typ: types.MustCreateDecimalType(20, 0), + nullable: false, + }, + { + name: "coalesce(sysUint, sysUint)", + input: []sql.Expression{ + expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)), + expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)), + }, + expected: 1, + typ: types.Uint64, + nullable: false, + }, + { + name: "coalesce(sysDouble, sysDouble)", + input: []sql.Expression{ + expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)), + expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)), + }, + expected: 1.0, + typ: types.Float64, + nullable: false, + }, + { + name: "coalesce(sysText)", + input: []sql.Expression{ + expression.NewLiteral("abc", types.NewSystemStringType("str1")), + }, + expected: "abc", + typ: types.LongText, + nullable: false, + }, + { + name: "coalesce(sysEnum)", + input: []sql.Expression{ + expression.NewLiteral("abc", types.NewSystemEnumType("str1")), + }, + expected: "abc", + typ: types.EnumType{}, + nullable: false, + }, + { + name: "coalesce(sysSet)", + input: []sql.Expression{ + expression.NewLiteral("abc", types.NewSystemSetType("str1", "abc")), + }, + expected: "abc", + typ: types.MustCreateSetType([]string{"abc"}, sql.Collation_Default), + nullable: false, + }, } for _, tt := range testCases { - c, err := NewCoalesce(tt.input...) - require.NoError(t, err) + t.Run(tt.name, func(t *testing.T) { + c, err := NewCoalesce(tt.input...) + require.NoError(t, err) - require.Equal(t, tt.typ, c.Type()) - require.Equal(t, tt.nullable, c.IsNullable()) - v, err := c.Eval(sql.NewEmptyContext(), nil) - require.NoError(t, err) - require.Equal(t, tt.expected, v) + require.Equal(t, tt.typ, c.Type()) + require.Equal(t, tt.nullable, c.IsNullable()) + v, err := c.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, tt.expected, v) + }) } } diff --git a/sql/types/typecheck.go b/sql/types/typecheck.go index 8ce1a65273..b2cb7f818d 100644 --- a/sql/types/typecheck.go +++ b/sql/types/typecheck.go @@ -91,8 +91,10 @@ func IsNull(ex sql.Expression) bool { // IsNumber checks if t is a number type func IsNumber(t sql.Type) bool { - switch t.(type) { - case NumberTypeImpl_, DecimalType_, BitType_, YearType_, SystemBoolType: + switch typ := t.(type) { + case sql.SystemVariableType: + return IsNumber(typ.UnderlyingType()) + case NumberTypeImpl_, DecimalType_, BitType_, YearType_: return true default: return false @@ -101,23 +103,25 @@ func IsNumber(t sql.Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t sql.Type) bool { - // systemBoolType is Int8 - if _, ok := t.(SystemBoolType); ok { - return true + if svt, ok := t.(sql.SystemVariableType); ok { + t = svt.UnderlyingType() } return t == Int8 || t == Int16 || t == Int24 || t == Int32 || t == Int64 || t == Boolean } // IsText checks if t is a CHAR, VARCHAR, TEXT, BINARY, VARBINARY, or BLOB (including TEXT and BLOB variants). func IsText(t sql.Type) bool { - if _, ok := t.(StringType); ok { - return ok - } - if extendedType, ok := t.(ExtendedType); ok { - _, isString := extendedType.Zero().(string) + switch typ := t.(type) { + case sql.SystemVariableType: + return IsText(typ.UnderlyingType()) + case StringType: + return true + case ExtendedType: + _, isString := typ.Zero().(string) return isString + default: + return false } - return false } // IsTextBlob checks if t is one of the TEXTs or BLOBs. @@ -178,14 +182,26 @@ func IsTimestampType(t sql.Type) bool { // IsEnum checks if t is a enum func IsEnum(t sql.Type) bool { - _, ok := t.(EnumType) - return ok + switch typ := t.(type) { + case sql.SystemVariableType: + return IsEnum(typ.UnderlyingType()) + case EnumType: + return true + default: + return false + } } // IsSet checks if t is a set func IsSet(t sql.Type) bool { - _, ok := t.(SetType) - return ok + switch typ := t.(type) { + case sql.SystemVariableType: + return IsSet(typ.UnderlyingType()) + case SetType: + return true + default: + return false + } } // IsTuple checks if t is a tuple type. @@ -201,7 +217,6 @@ func IsUnsigned(t sql.Type) bool { if svt, ok := t.(sql.SystemVariableType); ok { t = svt.UnderlyingType() } - return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64 } diff --git a/sql/types/typecheck_test.go b/sql/types/typecheck_test.go index a67c16eb33..783685723c 100644 --- a/sql/types/typecheck_test.go +++ b/sql/types/typecheck_test.go @@ -36,3 +36,63 @@ func TestIsJSON(t *testing.T) { assert.False(t, IsJSON(NumberTypeImpl_{})) assert.False(t, IsJSON(StringType{})) } + +func TestSystemTypesIsNumber(t *testing.T) { + assert.True(t, IsNumber(SystemBoolType{})) + assert.True(t, IsNumber(systemIntType{})) + assert.True(t, IsNumber(systemUintType{})) + assert.True(t, IsNumber(systemDoubleType{})) + assert.False(t, IsNumber(systemEnumType{})) + assert.False(t, IsNumber(systemSetType{})) + assert.False(t, IsNumber(systemStringType{})) +} + +func TestSystemTypesIsSigned(t *testing.T) { + assert.True(t, IsSigned(SystemBoolType{})) + assert.True(t, IsSigned(systemIntType{})) + assert.False(t, IsSigned(systemUintType{})) + assert.False(t, IsSigned(systemDoubleType{})) + assert.False(t, IsSigned(systemEnumType{})) + assert.False(t, IsSigned(systemSetType{})) + assert.False(t, IsSigned(systemStringType{})) +} + +func TestSystemTypesIsUnsigned(t *testing.T) { + assert.False(t, IsUnsigned(SystemBoolType{})) + assert.False(t, IsUnsigned(systemIntType{})) + assert.True(t, IsUnsigned(systemUintType{})) + assert.False(t, IsUnsigned(systemDoubleType{})) + assert.False(t, IsUnsigned(systemEnumType{})) + assert.False(t, IsUnsigned(systemSetType{})) + assert.False(t, IsUnsigned(systemStringType{})) +} + +func TestSystemTypesIsText(t *testing.T) { + assert.False(t, IsText(SystemBoolType{})) + assert.False(t, IsText(systemIntType{})) + assert.False(t, IsText(systemUintType{})) + assert.False(t, IsText(systemDoubleType{})) + assert.False(t, IsText(systemEnumType{})) + assert.False(t, IsText(systemSetType{})) + assert.True(t, IsText(systemStringType{})) +} + +func TestSystemTypesIsEnum(t *testing.T) { + assert.False(t, IsEnum(SystemBoolType{})) + assert.False(t, IsEnum(systemIntType{})) + assert.False(t, IsEnum(systemUintType{})) + assert.False(t, IsEnum(systemDoubleType{})) + assert.True(t, IsEnum(systemEnumType{})) + assert.False(t, IsEnum(systemSetType{})) + assert.False(t, IsEnum(systemStringType{})) +} + +func TestSystemTypesIsSet(t *testing.T) { + assert.False(t, IsSet(SystemBoolType{})) + assert.False(t, IsSet(systemIntType{})) + assert.False(t, IsSet(systemUintType{})) + assert.False(t, IsSet(systemDoubleType{})) + assert.False(t, IsSet(systemEnumType{})) + assert.True(t, IsSet(NewSystemSetType("", ""))) + assert.False(t, IsSet(systemStringType{})) +} \ No newline at end of file From 79b6f54b7ec441f4e8f1a8c13fd343299078a3f0 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 15 Oct 2024 00:06:44 +0000 Subject: [PATCH 2/2] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/coalesce.go | 5 +++-- sql/expression/function/coalesce_test.go | 6 +++--- sql/types/typecheck_test.go | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 9d60e11a02..11fc9e04ec 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -16,11 +16,12 @@ package function import ( "fmt" + "strings" + "github.com/dolthub/go-mysql-server/sql/expression" -"strings" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/types" ) // Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go index bccdeedb80..dbb9e1894d 100644 --- a/sql/expression/function/coalesce_test.go +++ b/sql/expression/function/coalesce_test.go @@ -15,10 +15,10 @@ package function import ( - "github.com/shopspring/decimal" -"testing" + "testing" - "github.com/stretchr/testify/require" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" diff --git a/sql/types/typecheck_test.go b/sql/types/typecheck_test.go index 783685723c..2a88cb82aa 100644 --- a/sql/types/typecheck_test.go +++ b/sql/types/typecheck_test.go @@ -95,4 +95,4 @@ func TestSystemTypesIsSet(t *testing.T) { assert.False(t, IsSet(systemEnumType{})) assert.True(t, IsSet(NewSystemSetType("", ""))) assert.False(t, IsSet(systemStringType{})) -} \ No newline at end of file +}