Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions server/ast/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/sirupsen/logrus"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// nodeAlterTable handles *tree.AlterTable nodes.
Expand Down Expand Up @@ -318,6 +319,10 @@ func nodeAlterTableAlterColumnType(ctx *Context, node *tree.AlterTableAlterColum
return nil, err
}

if resolvedType == pgtypes.Record {
return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Column.String())
}

return &vitess.DDL{
Action: "alter",
Table: tableName,
Expand Down
4 changes: 4 additions & 0 deletions server/ast/column_table_def.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
return nil, err
}

if resolvedType == pgtypes.Record {
return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Name.String())
}

var isNull vitess.BoolVal
var isNotNull vitess.BoolVal
switch node.Nullable.Nullability {
Expand Down
6 changes: 6 additions & 0 deletions server/ast/create_domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgnodes "github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// nodeCreateDomain handles *tree.CreateDomain nodes.
Expand All @@ -37,6 +38,11 @@ func nodeCreateDomain(ctx *Context, node *tree.CreateDomain) (vitess.Statement,
if err != nil {
return nil, err
}

if dataType == pgtypes.Record {
return nil, errors.Errorf(`"record" is not a valid base type for a domain`)
}

if node.Collate != "" {
return nil, errors.Errorf("domain collation is not yet supported")
}
Expand Down
6 changes: 6 additions & 0 deletions server/ast/create_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgnodes "github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// nodeCreateType handles *tree.CreateType nodes.
Expand All @@ -43,6 +44,11 @@ func nodeCreateType(ctx *Context, node *tree.CreateType) (vitess.Statement, erro
if err != nil {
return nil, err
}

if dataType == pgtypes.Record {
return nil, errors.Errorf(`column "%s" has pseudo-type record`, t.AttrName)
}

typs[i] = pgnodes.CompositeAsType{
AttrName: t.AttrName,
Typ: dataType,
Expand Down
19 changes: 13 additions & 6 deletions server/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,15 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
Expression: pgexprs.NewInSubquery(),
Children: vitess.Exprs{left, right},
}, nil
default:
return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right)
case vitess.InjectedExpr:
if _, ok := right.Expression.(*pgexprs.RecordExpr); ok {
return vitess.InjectedExpr{
Expression: pgexprs.NewInTuple(),
Children: vitess.Exprs{left, vitess.ValTuple(right.Children)},
}, nil
}
}
return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right)
case tree.NotIn:
innerExpr := vitess.InjectedExpr{
Expression: pgexprs.NewInTuple(),
Expand Down Expand Up @@ -776,15 +782,16 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
if len(node.Labels) > 0 {
return nil, errors.Errorf("tuple labels are not yet supported")
}
if node.Row {
return nil, errors.Errorf("ROW keyword for tuples not yet supported")
}

valTuple, err := nodeExprs(ctx, node.Exprs)
if err != nil {
return nil, err
}
return vitess.ValTuple(valTuple), nil

return vitess.InjectedExpr{
Expression: pgexprs.NewRecordExpr(),
Children: valTuple,
}, nil
case *tree.TupleStar:
return nil, errors.Errorf("(E).* is not yet supported")
case *tree.UnaryExpr:
Expand Down
2 changes: 2 additions & 0 deletions server/ast/resolvable_type_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
return nil, nil, errors.Errorf("geography types are not yet supported")
} else {
switch columnType.Oid() {
case oid.T_record:
resolvedType = pgtypes.Record
case oid.T_bool:
resolvedType = pgtypes.Bool
case oid.T_bytea:
Expand Down
154 changes: 154 additions & 0 deletions server/compare/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package compare

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"

"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// CompareRecords compares two record values using the specified comparison operator, |op| and returns a result
// indicating if the comparison was true, false, or indeterminate (nil).
//
// More info on rules for comparing records:
// https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON
func CompareRecords(ctx *sql.Context, op framework.Operator, v1 interface{}, v2 interface{}) (result any, err error) {
leftRecord, rightRecord, err := checkRecordArgs(v1, v2)
if err != nil {
return nil, err
}

hasNull := false
hasEqualFields := true

for i := 0; i < len(leftRecord); i++ {
typ1 := leftRecord[i].Type
typ2 := rightRecord[i].Type

// NULL values are by definition not comparable, so they need special handling depending
// on what type of comparison we're performing.
if leftRecord[i].Value == nil || rightRecord[i].Value == nil {
switch op {
case framework.Operator_BinaryEqual:
// If we're comparing for equality, then any presence of a NULL value means
// we don't have enough information to determine equality, so return nil.
return nil, nil

case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan,
framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryGreaterOrEqual:
// If we haven't seen a prior field with non-equivalent values, then we
// don't have enough certainty to make a comparison, so return nil.
if hasEqualFields {
return nil, nil
}
}

// Otherwise, mark that we've seen a NULL and skip over it
hasNull = true
continue
}

leftLiteral := expression.NewLiteral(leftRecord[i].Value, typ1)
rightLiteral := expression.NewLiteral(rightRecord[i].Value, typ2)

// For >= and <=, we need to distinguish between < and = (and > and =). Records
// are compared by evaluating each field, in order of significance, so for >= if
// the field is greater than, then we can stop comparing and return true immediately.
// If the field is equal, then we need to look at the next field. For this reason,
// we have to break >= and <= into separate comparisons for > or < and =.
switch op {
case framework.Operator_BinaryLessThan, framework.Operator_BinaryLessOrEqual:
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryLessThan, leftLiteral, rightLiteral); err != nil {
return false, err
} else if res == true {
return true, nil
}
case framework.Operator_BinaryGreaterThan, framework.Operator_BinaryGreaterOrEqual:
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryGreaterThan, leftLiteral, rightLiteral); err != nil {
return false, err
} else if res == true {
return true, nil
}
}

// After we've determined > and <, we can look at the equality comparison. For < and >, we've already returned
// true if that initial comparison was true. Now we need to determine if the two fields are equal, in which case
// we continue on to check the next field. If the two fields are NOT equal, then we can return false immediately.
switch op {
case framework.Operator_BinaryGreaterOrEqual, framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryEqual:
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
return false, err
} else if res == false {
return false, nil
}
case framework.Operator_BinaryNotEqual:
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryNotEqual, leftLiteral, rightLiteral); err != nil {
return false, err
} else if res == true {
return true, nil
}
case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan:
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
return false, err
} else if res == false {
// For < and >, we still need to check equality to know if we need to continue checking additional fields.
// If
hasEqualFields = false
}
default:
return false, fmt.Errorf("unsupportd binary operator: %s", op)
}
}

// If the records contain any NULL fields, but all non-NULL fields are equal, then we
// don't have enough certainty to return a result.
if hasNull && hasEqualFields {
return nil, nil
}

return true, nil
}

// checkRecordArgs asserts that |v1| and |v2| are both []pgtypes.RecordValue, and that they have the same number of
// elements, then returns them. If any problems were detected, an error is returnd instead.
func checkRecordArgs(v1, v2 interface{}) (leftRecord, rightRecord []pgtypes.RecordValue, err error) {
leftRecord, ok1 := v1.([]pgtypes.RecordValue)
rightRecord, ok2 := v2.([]pgtypes.RecordValue)
if !ok1 {
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v1)
} else if !ok2 {
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v2)
}

if len(leftRecord) != len(rightRecord) {
return nil, nil, fmt.Errorf("unequal number of entries in row expressions")
}

return leftRecord, rightRecord, nil
}

// callComparisonFunction invokes the binary comparison function for the specified operator |op| with the two arguments
// |leftLiteral| and |rightLiteral|. The result and any error are returned.
func callComparisonFunction(ctx *sql.Context, op framework.Operator, leftLiteral, rightLiteral sql.Expression) (result any, err error) {
intermediateFunction := framework.GetBinaryFunction(op)
compiledFunction := intermediateFunction.Compile(
"_internal_record_comparison_function", leftLiteral, rightLiteral)
return compiledFunction.Eval(ctx, nil)
}
15 changes: 11 additions & 4 deletions server/expression/in_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,18 @@ func (it *InTuple) WithResolvedChildren(children []any) (any, error) {
if !ok {
return nil, errors.Errorf("expected vitess child to be an expression but has type `%T`", children[0])
}
right, ok := children[1].(expression.Tuple)
if !ok {
return nil, errors.Errorf("expected vitess child to be an expression tuple but has type `%T`", children[1])

switch right := children[1].(type) {
case expression.Tuple:
return it.WithChildren(left, right)
case *RecordExpr:
// TODO: For now, if we see a RecordExpr come in, we convert it to a vitess Tuple representation, so that
// the existing in_tuple code can work with it. Alternatively, we could change in_tuple to always
// work directly with a Record expression.
return it.WithChildren(left, expression.Tuple(right.exprs))
default:
return nil, errors.Errorf("expected child to be a RecordExpr or vitess Tuple but has type `%T`", children[1])
}
return it.WithChildren(left, right)
}

// Left implements the expression.BinaryExpression interface.
Expand Down
Loading