From 7cdc339d22b682de93cb1930ffb3c632f1715309 Mon Sep 17 00:00:00 2001 From: David Dansby <39511285+codeaucafe@users.noreply.github.com> Date: Thu, 8 Jan 2026 21:51:26 -0800 Subject: [PATCH] feat(analyzer): fix VALUES clause type inference Add ResolveValuesTypes analyzer rule to compute common types across all VALUES rows, not just the first row. Previously, DoltgreSQL would incorrectly use only the first value to determine column types, causing errors when subsequent values had different types like VALUES(1),(2.01),(3). Changes: - Two-pass transformation strategy: first pass transforms VDT nodes with unified types, second pass updates GetField expressions in parent nodes - Use FindCommonType() to resolve types per PostgreSQL rules - Apply ImplicitCast for type conversions and UnknownCoercion for unknown-typed literals - Handle aggregates via getSourceSchema() - Add UnknownCoercion expression type for unknown -> target coercion without conversion Tests: - Add 4 bats integration tests for mixed int/decimal VALUES - Add 3 Go test cases covering int-first, decimal-first, SUM aggregate, and multi-column scenarios Refs: #1648 --- server/analyzer/init.go | 2 + server/analyzer/resolve_values_types.go | 331 ++++++++++++++++++++++ server/expression/implicit_cast.go | 57 ++++ server/functions/framework/common_type.go | 14 +- testing/bats/types.bats | 34 +++ testing/go/values_statement_test.go | 43 +++ 6 files changed, 477 insertions(+), 4 deletions(-) create mode 100644 server/analyzer/resolve_values_types.go diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 1047bb2457..95c99cc05c 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -49,6 +49,7 @@ const ( ruleId_ValidateCreateSchema // validateCreateSchema ruleId_ResolveAlterColumn // resolveAlterColumn ruleId_ValidateCreateFunction + ruleId_ResolveValuesTypes // resolveValuesTypes ) // Init adds additional rules to the analyzer to handle Doltgres-specific functionality. @@ -56,6 +57,7 @@ func Init() { analyzer.AlwaysBeforeDefault = append(analyzer.AlwaysBeforeDefault, analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType}, analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer}, + analyzer.Rule{Id: ruleId_ResolveValuesTypes, Apply: ResolveValuesTypes}, analyzer.Rule{Id: ruleId_GenerateForeignKeyName, Apply: generateForeignKeyName}, analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints}, analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults}, diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go new file mode 100644 index 0000000000..d413e9843a --- /dev/null +++ b/server/analyzer/resolve_values_types.go @@ -0,0 +1,331 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/transform" + + pgexprs "github.com/dolthub/doltgresql/server/expression" + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// ResolveValuesTypes determines the common type for each column in a VALUES clause +// by examining all rows, following PostgreSQL's type resolution rules. +// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer. +func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + // Track which VDTs we transform so we can update parent nodes + transformedVDTs := make(map[*plan.ValueDerivedTable]sql.Schema) + + // First pass: transform VDTs and record their new schemas + node, same1, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + newNode, same, err := transformValuesNode(n) + if err != nil { + return nil, same, err + } + if !same { + if vdt, ok := newNode.(*plan.ValueDerivedTable); ok { + transformedVDTs[vdt] = vdt.Schema() + } + } + return newNode, same, err + }) + if err != nil { + return nil, transform.SameTree, err + } + + // Second pass: update GetField types in parent nodes that reference transformed VDTs + if len(transformedVDTs) > 0 { + node, _, err = transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + return updateGetFieldTypes(n, transformedVDTs) + }) + if err != nil { + return nil, transform.SameTree, err + } + } + + return node, same1, nil +} + +// getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find +// the actual source schema from a VDT or other data source. This is needed because +// nodes like GroupBy produce a different output schema than their input schema. +func getSourceSchema(n sql.Node) sql.Schema { + switch node := n.(type) { + case *plan.GroupBy: + // GroupBy's Schema() returns aggregate output, but we need the source schema + return getSourceSchema(node.Child) + case *plan.Filter: + return getSourceSchema(node.Child) + case *plan.Sort: + return getSourceSchema(node.Child) + case *plan.Limit: + return getSourceSchema(node.Child) + case *plan.Offset: + return getSourceSchema(node.Child) + case *plan.Distinct: + return getSourceSchema(node.Child) + case *plan.SubqueryAlias: + // SubqueryAlias wraps a VDT - get the child's schema + return node.Child.Schema() + case *plan.ValueDerivedTable: + return node.Schema() + default: + // For other nodes, return their schema directly + return n.Schema() + } +} + +// updateGetFieldTypes updates GetField expressions that reference transformed VDT columns +func updateGetFieldTypes(n sql.Node, transformedVDTs map[*plan.ValueDerivedTable]sql.Schema) (sql.Node, transform.TreeIdentity, error) { + // Only handle nodes that have expressions (like Project) + exprNode, ok := n.(sql.Expressioner) + if !ok { + return n, transform.SameTree, nil + } + + // Get the source schema by traversing through wrapper nodes like GroupBy + // This ensures we get the VDT's schema, not the aggregate output schema + var childSchema sql.Schema + switch node := n.(type) { + case *plan.Project: + childSchema = getSourceSchema(node.Child) + case *plan.SubqueryAlias: + childSchema = node.Child.Schema() + default: + return n, transform.SameTree, nil + } + + if childSchema == nil { + return n, transform.SameTree, nil + } + + // Transform expressions to update GetField types (recursively for nested expressions) + exprs := exprNode.Expressions() + newExprs := make([]sql.Expression, len(exprs)) + changed := false + + for i, expr := range exprs { + newExpr, exprChanged, err := updateGetFieldExprRecursive(expr, childSchema) + if err != nil { + return nil, transform.SameTree, err + } + newExprs[i] = newExpr + if exprChanged { + changed = true + } + } + + if !changed { + return n, transform.SameTree, nil + } + + newNode, err := exprNode.WithExpressions(newExprs...) + if err != nil { + return nil, transform.SameTree, err + } + return newNode.(sql.Node), transform.NewTree, nil +} + +// updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree +func updateGetFieldExprRecursive(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { + // First try to update if this is a GetField + if _, ok := expr.(*expression.GetField); ok { + return updateGetFieldExpr(expr, childSchema) + } + + // Recursively process children + children := expr.Children() + if len(children) == 0 { + return expr, false, nil + } + + newChildren := make([]sql.Expression, len(children)) + changed := false + for i, child := range children { + newChild, childChanged, err := updateGetFieldExprRecursive(child, childSchema) + if err != nil { + return nil, false, err + } + newChildren[i] = newChild + if childChanged { + changed = true + } + } + + if !changed { + return expr, false, nil + } + + newExpr, err := expr.WithChildren(newChildren...) + if err != nil { + return nil, false, err + } + return newExpr, true, nil +} + +// updateGetFieldExpr updates a GetField expression to use the correct type from the child schema +func updateGetFieldExpr(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { + gf, ok := expr.(*expression.GetField) + if !ok { + return expr, false, nil + } + + idx := gf.Index() + // GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access + schemaIdx := idx - 1 + if schemaIdx < 0 || schemaIdx >= len(childSchema) { + return expr, false, nil + } + + newType := childSchema[schemaIdx].Type + if gf.Type() == newType { + return expr, false, nil + } + + // Create a new GetField with the updated type + newGf := expression.NewGetFieldWithTable( + idx, + int(gf.TableId()), + newType, + gf.Database(), + gf.Table(), + gf.Name(), + gf.IsNullable(), + ) + return newGf, true, nil +} + +// transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types +func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + // Handle both ValueDerivedTable and Values nodes + var values *plan.Values + var vdt *plan.ValueDerivedTable + var isVDT bool + + switch v := n.(type) { + case *plan.ValueDerivedTable: + vdt = v + values = v.Values + isVDT = true + case *plan.Values: + values = v + isVDT = false + default: + return n, transform.SameTree, nil + } + + // Skip if no rows or single row (nothing to unify) + if len(values.ExpressionTuples) <= 1 { + return n, transform.SameTree, nil + } + + numCols := len(values.ExpressionTuples[0]) + if numCols == 0 { + return n, transform.SameTree, nil + } + + // Collect types for each column across all rows + columnTypes := make([][]*pgtypes.DoltgresType, numCols) + for colIdx := 0; colIdx < numCols; colIdx++ { + columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples)) + for rowIdx, row := range values.ExpressionTuples { + exprType := row[colIdx].Type() + if exprType == nil { + columnTypes[colIdx][rowIdx] = pgtypes.Unknown + } else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok { + columnTypes[colIdx][rowIdx] = pgType + } else { + // Non-DoltgresType encountered - should have been sanitized + // Return unchanged and let TypeSanitizer handle it + return n, transform.SameTree, nil + } + } + } + + // Find common type for each column + commonTypes := make([]*pgtypes.DoltgresType, numCols) + for colIdx := 0; colIdx < numCols; colIdx++ { + commonType, err := framework.FindCommonType(columnTypes[colIdx]) + if err != nil { + return nil, transform.NewTree, err + } + commonTypes[colIdx] = commonType + } + + // Check if any changes are needed + needsChange := false + for colIdx := 0; colIdx < numCols; colIdx++ { + for rowIdx := 0; rowIdx < len(values.ExpressionTuples); rowIdx++ { + if !columnTypes[colIdx][rowIdx].Equals(commonTypes[colIdx]) { + needsChange = true + break + } + } + if needsChange { + break + } + } + + if !needsChange { + return n, transform.SameTree, nil + } + + // Create new expression tuples with implicit casts where needed + newTuples := make([][]sql.Expression, len(values.ExpressionTuples)) + for rowIdx, row := range values.ExpressionTuples { + newTuples[rowIdx] = make([]sql.Expression, numCols) + for colIdx, expr := range row { + fromType := columnTypes[colIdx][rowIdx] + toType := commonTypes[colIdx] + if fromType.Equals(toType) { + newTuples[rowIdx][colIdx] = expr + } else if fromType.ID == pgtypes.Unknown.ID { + // Unknown type can be coerced to any type without explicit cast + // Use UnknownCoercion to report the target type while passing through values + newTuples[rowIdx][colIdx] = pgexprs.NewUnknownCoercion(expr, toType) + } else { + newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(expr, fromType, toType) + } + } + } + + // Flatten the new tuples into a single expression slice for WithExpressions + var flatExprs []sql.Expression + for _, row := range newTuples { + flatExprs = append(flatExprs, row...) + } + + if isVDT { + // Use WithExpressions to preserve all VDT fields (name, columns, id, cols) + // while updating the expressions and recalculating the schema + newNode, err := vdt.WithExpressions(flatExprs...) + if err != nil { + return nil, transform.NewTree, err + } + return newNode, transform.NewTree, nil + } + + // For standalone Values node, use WithExpressions as well + newNode, err := values.WithExpressions(flatExprs...) + if err != nil { + return nil, transform.NewTree, err + } + return newNode, transform.NewTree, nil +} diff --git a/server/expression/implicit_cast.go b/server/expression/implicit_cast.go index fe2474a9fc..6232c9c346 100644 --- a/server/expression/implicit_cast.go +++ b/server/expression/implicit_cast.go @@ -32,6 +32,63 @@ type ImplicitCast struct { var _ sql.Expression = (*ImplicitCast)(nil) +// UnknownCoercion wraps an expression with unknown type to coerce it to a target type. +// Unlike ImplicitCast, this doesn't perform any actual conversion - it just changes the +// reported type since unknown type literals can coerce to any type in PostgreSQL. +type UnknownCoercion struct { + expr sql.Expression + toType *pgtypes.DoltgresType +} + +var _ sql.Expression = (*UnknownCoercion)(nil) + +// NewUnknownCoercion returns a new *UnknownCoercion expression. +func NewUnknownCoercion(expr sql.Expression, toType *pgtypes.DoltgresType) *UnknownCoercion { + return &UnknownCoercion{ + expr: expr, + toType: toType, + } +} + +// Children implements the sql.Expression interface. +func (uc *UnknownCoercion) Children() []sql.Expression { + return []sql.Expression{uc.expr} +} + +// Eval implements the sql.Expression interface. +func (uc *UnknownCoercion) Eval(ctx *sql.Context, row sql.Row) (any, error) { + // Just pass through - unknown type values can coerce to any type + return uc.expr.Eval(ctx, row) +} + +// IsNullable implements the sql.Expression interface. +func (uc *UnknownCoercion) IsNullable() bool { + return uc.expr.IsNullable() +} + +// Resolved implements the sql.Expression interface. +func (uc *UnknownCoercion) Resolved() bool { + return uc.expr.Resolved() +} + +// String implements the sql.Expression interface. +func (uc *UnknownCoercion) String() string { + return uc.expr.String() +} + +// Type implements the sql.Expression interface. +func (uc *UnknownCoercion) Type() sql.Type { + return uc.toType +} + +// WithChildren implements the sql.Expression interface. +func (uc *UnknownCoercion) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(uc, len(children), 1) + } + return NewUnknownCoercion(children[0], uc.toType), nil +} + // NewImplicitCast returns a new *ImplicitCast expression. func NewImplicitCast(expr sql.Expression, fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) *ImplicitCast { toType = checkForDomainType(toType) diff --git a/server/functions/framework/common_type.go b/server/functions/framework/common_type.go index f06f22f90a..57f301c16d 100644 --- a/server/functions/framework/common_type.go +++ b/server/functions/framework/common_type.go @@ -55,16 +55,22 @@ func FindCommonType(types []*pgtypes.DoltgresType) (*pgtypes.DoltgresType, error if typ.ID == pgtypes.Unknown.ID { continue } else if GetImplicitCast(typ, candidateType) != nil { + // typ can convert to candidateType, so candidateType is at least as general continue } else if GetImplicitCast(candidateType, typ) == nil { return nil, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) - } else if !preferredTypeFound { + } else { + // candidateType can convert to typ, but not vice versa, so typ is more general + // Per PostgreSQL docs: "If the resolution type can be implicitly converted to the + // other type but not vice-versa, select the other type as the new resolution type." + candidateType = typ if candidateType.IsPreferred { - candidateType = typ + // "Then, if the new resolution type is preferred, stop considering further inputs." preferredTypeFound = true } - } else { - return nil, errors.Errorf("found another preferred candidate type") + } + if preferredTypeFound { + break } } return candidateType, nil diff --git a/testing/bats/types.bats b/testing/bats/types.bats index aaeb6a80b3..bc143e54f0 100644 --- a/testing/bats/types.bats +++ b/testing/bats/types.bats @@ -38,3 +38,37 @@ SQL [[ "$output" =~ '5,{t}' ]] || false [[ "$output" =~ '6,{f}' ]] || false } + +@test 'types: VALUES clause mixed int and decimal' { + # Integer first, then decimal - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.01" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'types: VALUES clause decimal first then int' { + # Decimal first, then integers - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1.01),(2),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1.01" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'types: VALUES clause SUM with mixed types' { + # SUM should work directly now that VALUES has correct type + run query_server -t -c "SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "6.01" ]] || false +} + +@test 'types: VALUES clause multiple columns mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "a" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "b" ]] || false +} diff --git a/testing/go/values_statement_test.go b/testing/go/values_statement_test.go index cf995bd301..0607a5e513 100644 --- a/testing/go/values_statement_test.go +++ b/testing/go/values_statement_test.go @@ -53,4 +53,47 @@ var ValuesStatementTests = []ScriptTest{ }, }, }, + { + Name: "VALUES with mixed int and decimal - issue 1648", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + // Integer first, then decimal - should resolve to numeric + Query: `SELECT * FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.01")}, + {Numeric("3")}, + }, + }, + { + // Decimal first, then integers - should resolve to numeric + Query: `SELECT * FROM (VALUES(1.01),(2),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1.01")}, + {Numeric("2")}, + {Numeric("3")}, + }, + }, + { + // SUM should work directly now that VALUES has correct type + // Note: SUM returns float64 (double precision) for numeric input + Query: `SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{{6.01}}, + }, + }, + }, + { + Name: "VALUES with multiple columns mixed types", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);`, + Expected: []sql.Row{ + {Numeric("1"), "a"}, + {Numeric("2.5"), "b"}, + }, + }, + }, + }, }