diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 1047bb2457..95c99cc05c 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -49,6 +49,7 @@ const ( ruleId_ValidateCreateSchema // validateCreateSchema ruleId_ResolveAlterColumn // resolveAlterColumn ruleId_ValidateCreateFunction + ruleId_ResolveValuesTypes // resolveValuesTypes ) // Init adds additional rules to the analyzer to handle Doltgres-specific functionality. @@ -56,6 +57,7 @@ func Init() { analyzer.AlwaysBeforeDefault = append(analyzer.AlwaysBeforeDefault, analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType}, analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer}, + analyzer.Rule{Id: ruleId_ResolveValuesTypes, Apply: ResolveValuesTypes}, analyzer.Rule{Id: ruleId_GenerateForeignKeyName, Apply: generateForeignKeyName}, analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints}, analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults}, diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go new file mode 100644 index 0000000000..d413e9843a --- /dev/null +++ b/server/analyzer/resolve_values_types.go @@ -0,0 +1,331 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/transform" + + pgexprs "github.com/dolthub/doltgresql/server/expression" + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// ResolveValuesTypes determines the common type for each column in a VALUES clause +// by examining all rows, following PostgreSQL's type resolution rules. +// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer. +func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + // Track which VDTs we transform so we can update parent nodes + transformedVDTs := make(map[*plan.ValueDerivedTable]sql.Schema) + + // First pass: transform VDTs and record their new schemas + node, same1, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + newNode, same, err := transformValuesNode(n) + if err != nil { + return nil, same, err + } + if !same { + if vdt, ok := newNode.(*plan.ValueDerivedTable); ok { + transformedVDTs[vdt] = vdt.Schema() + } + } + return newNode, same, err + }) + if err != nil { + return nil, transform.SameTree, err + } + + // Second pass: update GetField types in parent nodes that reference transformed VDTs + if len(transformedVDTs) > 0 { + node, _, err = transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + return updateGetFieldTypes(n, transformedVDTs) + }) + if err != nil { + return nil, transform.SameTree, err + } + } + + return node, same1, nil +} + +// getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find +// the actual source schema from a VDT or other data source. This is needed because +// nodes like GroupBy produce a different output schema than their input schema. +func getSourceSchema(n sql.Node) sql.Schema { + switch node := n.(type) { + case *plan.GroupBy: + // GroupBy's Schema() returns aggregate output, but we need the source schema + return getSourceSchema(node.Child) + case *plan.Filter: + return getSourceSchema(node.Child) + case *plan.Sort: + return getSourceSchema(node.Child) + case *plan.Limit: + return getSourceSchema(node.Child) + case *plan.Offset: + return getSourceSchema(node.Child) + case *plan.Distinct: + return getSourceSchema(node.Child) + case *plan.SubqueryAlias: + // SubqueryAlias wraps a VDT - get the child's schema + return node.Child.Schema() + case *plan.ValueDerivedTable: + return node.Schema() + default: + // For other nodes, return their schema directly + return n.Schema() + } +} + +// updateGetFieldTypes updates GetField expressions that reference transformed VDT columns +func updateGetFieldTypes(n sql.Node, transformedVDTs map[*plan.ValueDerivedTable]sql.Schema) (sql.Node, transform.TreeIdentity, error) { + // Only handle nodes that have expressions (like Project) + exprNode, ok := n.(sql.Expressioner) + if !ok { + return n, transform.SameTree, nil + } + + // Get the source schema by traversing through wrapper nodes like GroupBy + // This ensures we get the VDT's schema, not the aggregate output schema + var childSchema sql.Schema + switch node := n.(type) { + case *plan.Project: + childSchema = getSourceSchema(node.Child) + case *plan.SubqueryAlias: + childSchema = node.Child.Schema() + default: + return n, transform.SameTree, nil + } + + if childSchema == nil { + return n, transform.SameTree, nil + } + + // Transform expressions to update GetField types (recursively for nested expressions) + exprs := exprNode.Expressions() + newExprs := make([]sql.Expression, len(exprs)) + changed := false + + for i, expr := range exprs { + newExpr, exprChanged, err := updateGetFieldExprRecursive(expr, childSchema) + if err != nil { + return nil, transform.SameTree, err + } + newExprs[i] = newExpr + if exprChanged { + changed = true + } + } + + if !changed { + return n, transform.SameTree, nil + } + + newNode, err := exprNode.WithExpressions(newExprs...) + if err != nil { + return nil, transform.SameTree, err + } + return newNode.(sql.Node), transform.NewTree, nil +} + +// updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree +func updateGetFieldExprRecursive(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { + // First try to update if this is a GetField + if _, ok := expr.(*expression.GetField); ok { + return updateGetFieldExpr(expr, childSchema) + } + + // Recursively process children + children := expr.Children() + if len(children) == 0 { + return expr, false, nil + } + + newChildren := make([]sql.Expression, len(children)) + changed := false + for i, child := range children { + newChild, childChanged, err := updateGetFieldExprRecursive(child, childSchema) + if err != nil { + return nil, false, err + } + newChildren[i] = newChild + if childChanged { + changed = true + } + } + + if !changed { + return expr, false, nil + } + + newExpr, err := expr.WithChildren(newChildren...) + if err != nil { + return nil, false, err + } + return newExpr, true, nil +} + +// updateGetFieldExpr updates a GetField expression to use the correct type from the child schema +func updateGetFieldExpr(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { + gf, ok := expr.(*expression.GetField) + if !ok { + return expr, false, nil + } + + idx := gf.Index() + // GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access + schemaIdx := idx - 1 + if schemaIdx < 0 || schemaIdx >= len(childSchema) { + return expr, false, nil + } + + newType := childSchema[schemaIdx].Type + if gf.Type() == newType { + return expr, false, nil + } + + // Create a new GetField with the updated type + newGf := expression.NewGetFieldWithTable( + idx, + int(gf.TableId()), + newType, + gf.Database(), + gf.Table(), + gf.Name(), + gf.IsNullable(), + ) + return newGf, true, nil +} + +// transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types +func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + // Handle both ValueDerivedTable and Values nodes + var values *plan.Values + var vdt *plan.ValueDerivedTable + var isVDT bool + + switch v := n.(type) { + case *plan.ValueDerivedTable: + vdt = v + values = v.Values + isVDT = true + case *plan.Values: + values = v + isVDT = false + default: + return n, transform.SameTree, nil + } + + // Skip if no rows or single row (nothing to unify) + if len(values.ExpressionTuples) <= 1 { + return n, transform.SameTree, nil + } + + numCols := len(values.ExpressionTuples[0]) + if numCols == 0 { + return n, transform.SameTree, nil + } + + // Collect types for each column across all rows + columnTypes := make([][]*pgtypes.DoltgresType, numCols) + for colIdx := 0; colIdx < numCols; colIdx++ { + columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples)) + for rowIdx, row := range values.ExpressionTuples { + exprType := row[colIdx].Type() + if exprType == nil { + columnTypes[colIdx][rowIdx] = pgtypes.Unknown + } else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok { + columnTypes[colIdx][rowIdx] = pgType + } else { + // Non-DoltgresType encountered - should have been sanitized + // Return unchanged and let TypeSanitizer handle it + return n, transform.SameTree, nil + } + } + } + + // Find common type for each column + commonTypes := make([]*pgtypes.DoltgresType, numCols) + for colIdx := 0; colIdx < numCols; colIdx++ { + commonType, err := framework.FindCommonType(columnTypes[colIdx]) + if err != nil { + return nil, transform.NewTree, err + } + commonTypes[colIdx] = commonType + } + + // Check if any changes are needed + needsChange := false + for colIdx := 0; colIdx < numCols; colIdx++ { + for rowIdx := 0; rowIdx < len(values.ExpressionTuples); rowIdx++ { + if !columnTypes[colIdx][rowIdx].Equals(commonTypes[colIdx]) { + needsChange = true + break + } + } + if needsChange { + break + } + } + + if !needsChange { + return n, transform.SameTree, nil + } + + // Create new expression tuples with implicit casts where needed + newTuples := make([][]sql.Expression, len(values.ExpressionTuples)) + for rowIdx, row := range values.ExpressionTuples { + newTuples[rowIdx] = make([]sql.Expression, numCols) + for colIdx, expr := range row { + fromType := columnTypes[colIdx][rowIdx] + toType := commonTypes[colIdx] + if fromType.Equals(toType) { + newTuples[rowIdx][colIdx] = expr + } else if fromType.ID == pgtypes.Unknown.ID { + // Unknown type can be coerced to any type without explicit cast + // Use UnknownCoercion to report the target type while passing through values + newTuples[rowIdx][colIdx] = pgexprs.NewUnknownCoercion(expr, toType) + } else { + newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(expr, fromType, toType) + } + } + } + + // Flatten the new tuples into a single expression slice for WithExpressions + var flatExprs []sql.Expression + for _, row := range newTuples { + flatExprs = append(flatExprs, row...) + } + + if isVDT { + // Use WithExpressions to preserve all VDT fields (name, columns, id, cols) + // while updating the expressions and recalculating the schema + newNode, err := vdt.WithExpressions(flatExprs...) + if err != nil { + return nil, transform.NewTree, err + } + return newNode, transform.NewTree, nil + } + + // For standalone Values node, use WithExpressions as well + newNode, err := values.WithExpressions(flatExprs...) + if err != nil { + return nil, transform.NewTree, err + } + return newNode, transform.NewTree, nil +} diff --git a/server/expression/implicit_cast.go b/server/expression/implicit_cast.go index fe2474a9fc..6232c9c346 100644 --- a/server/expression/implicit_cast.go +++ b/server/expression/implicit_cast.go @@ -32,6 +32,63 @@ type ImplicitCast struct { var _ sql.Expression = (*ImplicitCast)(nil) +// UnknownCoercion wraps an expression with unknown type to coerce it to a target type. +// Unlike ImplicitCast, this doesn't perform any actual conversion - it just changes the +// reported type since unknown type literals can coerce to any type in PostgreSQL. +type UnknownCoercion struct { + expr sql.Expression + toType *pgtypes.DoltgresType +} + +var _ sql.Expression = (*UnknownCoercion)(nil) + +// NewUnknownCoercion returns a new *UnknownCoercion expression. +func NewUnknownCoercion(expr sql.Expression, toType *pgtypes.DoltgresType) *UnknownCoercion { + return &UnknownCoercion{ + expr: expr, + toType: toType, + } +} + +// Children implements the sql.Expression interface. +func (uc *UnknownCoercion) Children() []sql.Expression { + return []sql.Expression{uc.expr} +} + +// Eval implements the sql.Expression interface. +func (uc *UnknownCoercion) Eval(ctx *sql.Context, row sql.Row) (any, error) { + // Just pass through - unknown type values can coerce to any type + return uc.expr.Eval(ctx, row) +} + +// IsNullable implements the sql.Expression interface. +func (uc *UnknownCoercion) IsNullable() bool { + return uc.expr.IsNullable() +} + +// Resolved implements the sql.Expression interface. +func (uc *UnknownCoercion) Resolved() bool { + return uc.expr.Resolved() +} + +// String implements the sql.Expression interface. +func (uc *UnknownCoercion) String() string { + return uc.expr.String() +} + +// Type implements the sql.Expression interface. +func (uc *UnknownCoercion) Type() sql.Type { + return uc.toType +} + +// WithChildren implements the sql.Expression interface. +func (uc *UnknownCoercion) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(uc, len(children), 1) + } + return NewUnknownCoercion(children[0], uc.toType), nil +} + // NewImplicitCast returns a new *ImplicitCast expression. func NewImplicitCast(expr sql.Expression, fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) *ImplicitCast { toType = checkForDomainType(toType) diff --git a/server/functions/framework/common_type.go b/server/functions/framework/common_type.go index f06f22f90a..57f301c16d 100644 --- a/server/functions/framework/common_type.go +++ b/server/functions/framework/common_type.go @@ -55,16 +55,22 @@ func FindCommonType(types []*pgtypes.DoltgresType) (*pgtypes.DoltgresType, error if typ.ID == pgtypes.Unknown.ID { continue } else if GetImplicitCast(typ, candidateType) != nil { + // typ can convert to candidateType, so candidateType is at least as general continue } else if GetImplicitCast(candidateType, typ) == nil { return nil, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) - } else if !preferredTypeFound { + } else { + // candidateType can convert to typ, but not vice versa, so typ is more general + // Per PostgreSQL docs: "If the resolution type can be implicitly converted to the + // other type but not vice-versa, select the other type as the new resolution type." + candidateType = typ if candidateType.IsPreferred { - candidateType = typ + // "Then, if the new resolution type is preferred, stop considering further inputs." preferredTypeFound = true } - } else { - return nil, errors.Errorf("found another preferred candidate type") + } + if preferredTypeFound { + break } } return candidateType, nil diff --git a/testing/bats/types.bats b/testing/bats/types.bats index aaeb6a80b3..bc143e54f0 100644 --- a/testing/bats/types.bats +++ b/testing/bats/types.bats @@ -38,3 +38,37 @@ SQL [[ "$output" =~ '5,{t}' ]] || false [[ "$output" =~ '6,{f}' ]] || false } + +@test 'types: VALUES clause mixed int and decimal' { + # Integer first, then decimal - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.01" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'types: VALUES clause decimal first then int' { + # Decimal first, then integers - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1.01),(2),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1.01" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'types: VALUES clause SUM with mixed types' { + # SUM should work directly now that VALUES has correct type + run query_server -t -c "SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "6.01" ]] || false +} + +@test 'types: VALUES clause multiple columns mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "a" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "b" ]] || false +} diff --git a/testing/go/values_statement_test.go b/testing/go/values_statement_test.go index cf995bd301..0607a5e513 100644 --- a/testing/go/values_statement_test.go +++ b/testing/go/values_statement_test.go @@ -53,4 +53,47 @@ var ValuesStatementTests = []ScriptTest{ }, }, }, + { + Name: "VALUES with mixed int and decimal - issue 1648", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + // Integer first, then decimal - should resolve to numeric + Query: `SELECT * FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.01")}, + {Numeric("3")}, + }, + }, + { + // Decimal first, then integers - should resolve to numeric + Query: `SELECT * FROM (VALUES(1.01),(2),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1.01")}, + {Numeric("2")}, + {Numeric("3")}, + }, + }, + { + // SUM should work directly now that VALUES has correct type + // Note: SUM returns float64 (double precision) for numeric input + Query: `SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{{6.01}}, + }, + }, + }, + { + Name: "VALUES with multiple columns mixed types", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);`, + Expected: []sql.Row{ + {Numeric("1"), "a"}, + {Numeric("2.5"), "b"}, + }, + }, + }, + }, }