diff --git a/server/ast/alter_table.go b/server/ast/alter_table.go index ca4ec37768..a7299d10ef 100644 --- a/server/ast/alter_table.go +++ b/server/ast/alter_table.go @@ -22,6 +22,7 @@ import ( "github.com/sirupsen/logrus" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeAlterTable handles *tree.AlterTable nodes. @@ -318,6 +319,10 @@ func nodeAlterTableAlterColumnType(ctx *Context, node *tree.AlterTableAlterColum return nil, err } + if resolvedType == pgtypes.Record { + return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Column.String()) + } + return &vitess.DDL{ Action: "alter", Table: tableName, diff --git a/server/ast/column_table_def.go b/server/ast/column_table_def.go index c5b43d9c33..aab63878ae 100644 --- a/server/ast/column_table_def.go +++ b/server/ast/column_table_def.go @@ -42,6 +42,10 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column return nil, err } + if resolvedType == pgtypes.Record { + return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Name.String()) + } + var isNull vitess.BoolVal var isNotNull vitess.BoolVal switch node.Nullable.Nullability { diff --git a/server/ast/create_domain.go b/server/ast/create_domain.go index dcd5cbd880..73ca412dac 100644 --- a/server/ast/create_domain.go +++ b/server/ast/create_domain.go @@ -22,6 +22,7 @@ import ( "github.com/dolthub/doltgresql/postgres/parser/sem/tree" pgnodes "github.com/dolthub/doltgresql/server/node" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeCreateDomain handles *tree.CreateDomain nodes. @@ -37,6 +38,11 @@ func nodeCreateDomain(ctx *Context, node *tree.CreateDomain) (vitess.Statement, if err != nil { return nil, err } + + if dataType == pgtypes.Record { + return nil, errors.Errorf(`"record" is not a valid base type for a domain`) + } + if node.Collate != "" { return nil, errors.Errorf("domain collation is not yet supported") } diff --git a/server/ast/create_type.go b/server/ast/create_type.go index 601aca6c81..1b96d3911e 100644 --- a/server/ast/create_type.go +++ b/server/ast/create_type.go @@ -21,6 +21,7 @@ import ( "github.com/dolthub/doltgresql/postgres/parser/sem/tree" pgnodes "github.com/dolthub/doltgresql/server/node" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeCreateType handles *tree.CreateType nodes. @@ -43,6 +44,11 @@ func nodeCreateType(ctx *Context, node *tree.CreateType) (vitess.Statement, erro if err != nil { return nil, err } + + if dataType == pgtypes.Record { + return nil, errors.Errorf(`column "%s" has pseudo-type record`, t.AttrName) + } + typs[i] = pgnodes.CompositeAsType{ AttrName: t.AttrName, Typ: dataType, diff --git a/server/ast/expr.go b/server/ast/expr.go index 6727201a51..b3a5c288db 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -373,9 +373,15 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { Expression: pgexprs.NewInSubquery(), Children: vitess.Exprs{left, right}, }, nil - default: - return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right) + case vitess.InjectedExpr: + if _, ok := right.Expression.(*pgexprs.RecordExpr); ok { + return vitess.InjectedExpr{ + Expression: pgexprs.NewInTuple(), + Children: vitess.Exprs{left, vitess.ValTuple(right.Children)}, + }, nil + } } + return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right) case tree.NotIn: innerExpr := vitess.InjectedExpr{ Expression: pgexprs.NewInTuple(), @@ -776,15 +782,16 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { if len(node.Labels) > 0 { return nil, errors.Errorf("tuple labels are not yet supported") } - if node.Row { - return nil, errors.Errorf("ROW keyword for tuples not yet supported") - } valTuple, err := nodeExprs(ctx, node.Exprs) if err != nil { return nil, err } - return vitess.ValTuple(valTuple), nil + + return vitess.InjectedExpr{ + Expression: pgexprs.NewRecordExpr(), + Children: valTuple, + }, nil case *tree.TupleStar: return nil, errors.Errorf("(E).* is not yet supported") case *tree.UnaryExpr: diff --git a/server/ast/resolvable_type_reference.go b/server/ast/resolvable_type_reference.go index 08867a466b..a82dc32586 100755 --- a/server/ast/resolvable_type_reference.go +++ b/server/ast/resolvable_type_reference.go @@ -72,6 +72,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference) return nil, nil, errors.Errorf("geography types are not yet supported") } else { switch columnType.Oid() { + case oid.T_record: + resolvedType = pgtypes.Record case oid.T_bool: resolvedType = pgtypes.Bool case oid.T_bytea: diff --git a/server/compare/utils.go b/server/compare/utils.go new file mode 100644 index 0000000000..e957779b9b --- /dev/null +++ b/server/compare/utils.go @@ -0,0 +1,154 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compare + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// CompareRecords compares two record values using the specified comparison operator, |op| and returns a result +// indicating if the comparison was true, false, or indeterminate (nil). +// +// More info on rules for comparing records: +// https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON +func CompareRecords(ctx *sql.Context, op framework.Operator, v1 interface{}, v2 interface{}) (result any, err error) { + leftRecord, rightRecord, err := checkRecordArgs(v1, v2) + if err != nil { + return nil, err + } + + hasNull := false + hasEqualFields := true + + for i := 0; i < len(leftRecord); i++ { + typ1 := leftRecord[i].Type + typ2 := rightRecord[i].Type + + // NULL values are by definition not comparable, so they need special handling depending + // on what type of comparison we're performing. + if leftRecord[i].Value == nil || rightRecord[i].Value == nil { + switch op { + case framework.Operator_BinaryEqual: + // If we're comparing for equality, then any presence of a NULL value means + // we don't have enough information to determine equality, so return nil. + return nil, nil + + case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan, + framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryGreaterOrEqual: + // If we haven't seen a prior field with non-equivalent values, then we + // don't have enough certainty to make a comparison, so return nil. + if hasEqualFields { + return nil, nil + } + } + + // Otherwise, mark that we've seen a NULL and skip over it + hasNull = true + continue + } + + leftLiteral := expression.NewLiteral(leftRecord[i].Value, typ1) + rightLiteral := expression.NewLiteral(rightRecord[i].Value, typ2) + + // For >= and <=, we need to distinguish between < and = (and > and =). Records + // are compared by evaluating each field, in order of significance, so for >= if + // the field is greater than, then we can stop comparing and return true immediately. + // If the field is equal, then we need to look at the next field. For this reason, + // we have to break >= and <= into separate comparisons for > or < and =. + switch op { + case framework.Operator_BinaryLessThan, framework.Operator_BinaryLessOrEqual: + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryLessThan, leftLiteral, rightLiteral); err != nil { + return false, err + } else if res == true { + return true, nil + } + case framework.Operator_BinaryGreaterThan, framework.Operator_BinaryGreaterOrEqual: + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryGreaterThan, leftLiteral, rightLiteral); err != nil { + return false, err + } else if res == true { + return true, nil + } + } + + // After we've determined > and <, we can look at the equality comparison. For < and >, we've already returned + // true if that initial comparison was true. Now we need to determine if the two fields are equal, in which case + // we continue on to check the next field. If the two fields are NOT equal, then we can return false immediately. + switch op { + case framework.Operator_BinaryGreaterOrEqual, framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryEqual: + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil { + return false, err + } else if res == false { + return false, nil + } + case framework.Operator_BinaryNotEqual: + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryNotEqual, leftLiteral, rightLiteral); err != nil { + return false, err + } else if res == true { + return true, nil + } + case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan: + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil { + return false, err + } else if res == false { + // For < and >, we still need to check equality to know if we need to continue checking additional fields. + // If + hasEqualFields = false + } + default: + return false, fmt.Errorf("unsupportd binary operator: %s", op) + } + } + + // If the records contain any NULL fields, but all non-NULL fields are equal, then we + // don't have enough certainty to return a result. + if hasNull && hasEqualFields { + return nil, nil + } + + return true, nil +} + +// checkRecordArgs asserts that |v1| and |v2| are both []pgtypes.RecordValue, and that they have the same number of +// elements, then returns them. If any problems were detected, an error is returnd instead. +func checkRecordArgs(v1, v2 interface{}) (leftRecord, rightRecord []pgtypes.RecordValue, err error) { + leftRecord, ok1 := v1.([]pgtypes.RecordValue) + rightRecord, ok2 := v2.([]pgtypes.RecordValue) + if !ok1 { + return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v1) + } else if !ok2 { + return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v2) + } + + if len(leftRecord) != len(rightRecord) { + return nil, nil, fmt.Errorf("unequal number of entries in row expressions") + } + + return leftRecord, rightRecord, nil +} + +// callComparisonFunction invokes the binary comparison function for the specified operator |op| with the two arguments +// |leftLiteral| and |rightLiteral|. The result and any error are returned. +func callComparisonFunction(ctx *sql.Context, op framework.Operator, leftLiteral, rightLiteral sql.Expression) (result any, err error) { + intermediateFunction := framework.GetBinaryFunction(op) + compiledFunction := intermediateFunction.Compile( + "_internal_record_comparison_function", leftLiteral, rightLiteral) + return compiledFunction.Eval(ctx, nil) +} diff --git a/server/expression/in_tuple.go b/server/expression/in_tuple.go index 3dae67a142..89e9259cdf 100644 --- a/server/expression/in_tuple.go +++ b/server/expression/in_tuple.go @@ -242,11 +242,18 @@ func (it *InTuple) WithResolvedChildren(children []any) (any, error) { if !ok { return nil, errors.Errorf("expected vitess child to be an expression but has type `%T`", children[0]) } - right, ok := children[1].(expression.Tuple) - if !ok { - return nil, errors.Errorf("expected vitess child to be an expression tuple but has type `%T`", children[1]) + + switch right := children[1].(type) { + case expression.Tuple: + return it.WithChildren(left, right) + case *RecordExpr: + // TODO: For now, if we see a RecordExpr come in, we convert it to a vitess Tuple representation, so that + // the existing in_tuple code can work with it. Alternatively, we could change in_tuple to always + // work directly with a Record expression. + return it.WithChildren(left, expression.Tuple(right.exprs)) + default: + return nil, errors.Errorf("expected child to be a RecordExpr or vitess Tuple but has type `%T`", children[1]) } - return it.WithChildren(left, right) } // Left implements the expression.BinaryExpression interface. diff --git a/server/expression/record.go b/server/expression/record.go new file mode 100644 index 0000000000..45628247b1 --- /dev/null +++ b/server/expression/record.go @@ -0,0 +1,110 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// NewRecordExpr creates a new record expression. +func NewRecordExpr() *RecordExpr { + return &RecordExpr{} +} + +// RecordExpr is a set of sql.Expressions wrapped together in a single value. +type RecordExpr struct { + exprs []sql.Expression +} + +var _ sql.Expression = (*RecordExpr)(nil) +var _ vitess.Injectable = (*RecordExpr)(nil) + +// Resolved implements the sql.Expression interface. +func (t *RecordExpr) Resolved() bool { + for _, expr := range t.exprs { + if !expr.Resolved() { + return false + } + } + return true +} + +// String implements the sql.Expression interface. +func (t *RecordExpr) String() string { + return "RECORD EXPR" +} + +// Type implements the sql.Expression interface. +func (t *RecordExpr) Type() sql.Type { + return pgtypes.Record +} + +// IsNullable implements the sql.Expression interface. +func (t *RecordExpr) IsNullable() bool { + return false +} + +// Eval implements the sql.Expression interface. +func (t *RecordExpr) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + vals := make([]pgtypes.RecordValue, len(t.exprs)) + for i, expr := range t.exprs { + val, err := expr.Eval(ctx, row) + if err != nil { + return nil, err + } + + typ, ok := expr.Type().(*pgtypes.DoltgresType) + if !ok { + return nil, fmt.Errorf("expected a DoltgresType, but got %T", expr.Type()) + } + vals[i] = pgtypes.RecordValue{ + Value: val, + Type: typ, + } + } + + return vals, nil +} + +// Children implements the sql.Expression interface. +func (t *RecordExpr) Children() []sql.Expression { + return t.exprs +} + +// WithChildren implements the sql.Expression interface. +func (t *RecordExpr) WithChildren(children ...sql.Expression) (sql.Expression, error) { + tCopy := *t + tCopy.exprs = children + return &tCopy, nil +} + +// WithResolvedChildren implements the vitess.Injectable interface +func (t *RecordExpr) WithResolvedChildren(children []any) (any, error) { + newExpressions := make([]sql.Expression, len(children)) + for i, resolvedChild := range children { + resolvedExpression, ok := resolvedChild.(sql.Expression) + if !ok { + return nil, errors.Errorf("expected vitess child to be an expression but has type `%T`", resolvedChild) + } + newExpressions[i] = resolvedExpression + } + return t.WithChildren(newExpressions...) +} diff --git a/server/functions/binary/equal.go b/server/functions/binary/equal.go index ed3bbbcd51..5b0db87bf2 100644 --- a/server/functions/binary/equal.go +++ b/server/functions/binary/equal.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -61,6 +62,7 @@ func initBinaryEqual() { framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, oideq) framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, texteqname) framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, text_eq) + framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, record_eq) framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, time_eq) framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, timestamp_eq_date) framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, timestamp_eq) @@ -423,6 +425,17 @@ var text_eq = framework.Function2{ }, } +// record_eq represents the PostgreSQL function of the same name, taking the same parameters. +var record_eq = framework.Function2{ + Name: "record_eq", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryEqual, val1, val2) + }, +} + // time_eq represents the PostgreSQL function of the same name, taking the same parameters. var time_eq = framework.Function2{ Name: "time_eq", diff --git a/server/functions/binary/greater.go b/server/functions/binary/greater.go index d1fccd8ddf..19804dcfdf 100644 --- a/server/functions/binary/greater.go +++ b/server/functions/binary/greater.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -70,6 +71,7 @@ func initBinaryGreaterThan() { framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, timestamptz_gt_timestamp) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, timestamptz_gt) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, timetz_gt) + framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, record_gt) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, uuid_gt) } @@ -517,6 +519,17 @@ var timetz_gt = framework.Function2{ }, } +// record_gt represents the PostgreSQL function of the same name, taking the same parameters. +var record_gt = framework.Function2{ + Name: "record_gt", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryGreaterThan, val1, val2) + }, +} + // uuid_gt represents the PostgreSQL function of the same name, taking the same parameters. var uuid_gt = framework.Function2{ Name: "uuid_gt", diff --git a/server/functions/binary/greater_equal.go b/server/functions/binary/greater_equal.go index 8c648cb29d..08fcca3e0b 100644 --- a/server/functions/binary/greater_equal.go +++ b/server/functions/binary/greater_equal.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -70,6 +71,7 @@ func initBinaryGreaterOrEqual() { framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, timestamptz_ge_timestamp) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, timestamptz_ge) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, timetz_ge) + framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, record_ge) framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, uuid_ge) } @@ -517,6 +519,17 @@ var timetz_ge = framework.Function2{ }, } +// record_ge represents the PostgreSQL function of the same name, taking the same parameters. +var record_ge = framework.Function2{ + Name: "record_ge", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryGreaterOrEqual, val1, val2) + }, +} + // uuid_ge represents the PostgreSQL function of the same name, taking the same parameters. var uuid_ge = framework.Function2{ Name: "uuid_ge", diff --git a/server/functions/binary/less.go b/server/functions/binary/less.go index fcde7e8974..70993d5253 100644 --- a/server/functions/binary/less.go +++ b/server/functions/binary/less.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -70,6 +71,7 @@ func initBinaryLessThan() { framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, timestamptz_lt_timestamp) framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, timestamptz_lt) framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, timetz_lt) + framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, record_lt) framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, uuid_lt) } @@ -517,6 +519,17 @@ var timetz_lt = framework.Function2{ }, } +// record_lt represents the PostgreSQL function of the same name, taking the same parameters. +var record_lt = framework.Function2{ + Name: "record_lt", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryLessThan, val1, val2) + }, +} + // uuid_lt represents the PostgreSQL function of the same name, taking the same parameters. var uuid_lt = framework.Function2{ Name: "uuid_lt", diff --git a/server/functions/binary/less_equal.go b/server/functions/binary/less_equal.go index c3e8e2e674..4710b4a011 100644 --- a/server/functions/binary/less_equal.go +++ b/server/functions/binary/less_equal.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -70,6 +71,7 @@ func initBinaryLessOrEqual() { framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, timestamptz_le_timestamp) framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, timestamptz_le) framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, timetz_le) + framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, record_le) framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, uuid_le) } @@ -517,6 +519,17 @@ var timetz_le = framework.Function2{ }, } +// record_le represents the PostgreSQL function of the same name, taking the same parameters. +var record_le = framework.Function2{ + Name: "record_le", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryLessThan, val1, val2) + }, +} + // uuid_le represents the PostgreSQL function of the same name, taking the same parameters. var uuid_le = framework.Function2{ Name: "uuid_le", diff --git a/server/functions/binary/not_equal.go b/server/functions/binary/not_equal.go index 1425d68879..9566125a33 100644 --- a/server/functions/binary/not_equal.go +++ b/server/functions/binary/not_equal.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/uuid" + "github.com/dolthub/doltgresql/server/compare" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -69,6 +70,7 @@ func initBinaryNotEqual() { framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, timestamptz_ne_timestamp) framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, timestamptz_ne) framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, timetz_ne) + framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, record_ne) framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, uuid_ne) framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, xidneqint4) framework.RegisterBinaryFunction(framework.Operator_BinaryNotEqual, xidneq) @@ -518,6 +520,17 @@ var timetz_ne = framework.Function2{ }, } +// record_ne represents the PostgreSQL function of the same name, taking the same parameters. +var record_ne = framework.Function2{ + Name: "record_ne", + Return: pgtypes.Bool, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record}, + Strict: true, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + return compare.CompareRecords(ctx, framework.Operator_BinaryNotEqual, val1, val2) + }, +} + // uuid_ne represents the PostgreSQL function of the same name, taking the same parameters. var uuid_ne = framework.Function2{ Name: "uuid_ne", diff --git a/server/functions/record.go b/server/functions/record.go index 45c5395e1b..8a1ebc4389 100644 --- a/server/functions/record.go +++ b/server/functions/record.go @@ -15,9 +15,9 @@ package functions import ( - "github.com/dolthub/go-mysql-server/sql" + "fmt" - "github.com/dolthub/doltgresql/utils" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -40,8 +40,7 @@ var record_in = framework.Function3{ Parameters: [3]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { - //typOid := val2.(id.Id) - return val1.(string), nil + return nil, fmt.Errorf("record_in not implemented") }, } @@ -51,9 +50,12 @@ var record_out = framework.Function1{ Return: pgtypes.Cstring, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Record}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - // TODO - return val.(string), nil + Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { + values, ok := val.([]pgtypes.RecordValue) + if !ok { + return nil, fmt.Errorf("expected []any, but got %T", val) + } + return pgtypes.RecordToString(ctx, values) }, } @@ -64,14 +66,7 @@ var record_recv = framework.Function3{ Parameters: [3]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { - // TODO - // typOid := val2.(id.Id) - data := val1.([]byte) - if len(data) == 0 { - return nil, nil - } - reader := utils.NewReader(data) - return reader.String(), nil + return nil, fmt.Errorf("record_recv not implemented") }, } @@ -81,12 +76,17 @@ var record_send = framework.Function1{ Return: pgtypes.Bytea, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Record}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - // TODO - str := val.(string) - writer := utils.NewWriter(uint64(len(str) + 4)) - writer.String(str) - return writer.Data(), nil + Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { + values, ok := val.([]pgtypes.RecordValue) + if !ok { + return nil, fmt.Errorf("expected []any, but got %T", val) + } + output, err := pgtypes.RecordToString(ctx, values) + if err != nil { + return nil, err + } + + return []byte(output.(string)), nil }, } diff --git a/server/types/record.go b/server/types/record.go index 7d310c9ce1..f555498e66 100644 --- a/server/types/record.go +++ b/server/types/record.go @@ -18,7 +18,9 @@ import ( "github.com/dolthub/doltgresql/core/id" ) -// Record is the record type, similar to a row. +// Record is a generic, anonymous record type, without field type information supplied yet. When used with RecordExpr, +// the field type information will be created once the field expressions are analyzed and type information is available, +// and a new DoltgresType instance will be created with the field type information populated. var Record = &DoltgresType{ ID: toInternal("record"), TypLength: -1, @@ -51,5 +53,12 @@ var Record = &DoltgresType{ Acl: nil, Checks: nil, attTypMod: -1, - CompareFunc: toFuncID("btrecordcmp", toInternal("record"), toInternal("record")), + CompareFunc: toFuncID("-"), +} + +// RecordValue holds the value of a single field in a record, including type information for the +// field value. +type RecordValue struct { + Value any + Type *DoltgresType } diff --git a/server/types/type.go b/server/types/type.go index 344c943466..571db883f9 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -36,11 +36,12 @@ import ( "github.com/dolthub/doltgresql/utils" ) -// DoltgresType represents a single type. -// TODO: the serialization logic always serializes every field for built-in types, which is kind of silly. They are +// DoltgresType represents a single type. Many of these fields map directly to the type definitions in the pg_types +// system table. See https://www.postgresql.org/docs/current/catalog-pg-type.html for more information. // -// effectively hard-coded. We could serialize much more cheaply by only serializing values which can't be derived -// (for custom types) and hard-coding everything else. +// TODO: the serialization logic always serializes every field for built-in types, which is kind of silly. They are +// effectively hard-coded. We could serialize much more cheaply by only serializing values which can't be derived +// (for custom types) and hard-coding everything else. type DoltgresType struct { ID id.Type TypType TypeType @@ -55,10 +56,10 @@ type DoltgresType struct { SubscriptFunc uint32 Elem id.Type Array id.Type - InputFunc uint32 - OutputFunc uint32 - ReceiveFunc uint32 - SendFunc uint32 + InputFunc uint32 // for deserializing a text representation + OutputFunc uint32 // for serializing a text representation + ReceiveFunc uint32 // for deserializing a binary representation + SendFunc uint32 // for serializing a binary representation ModInFunc uint32 ModOutFunc uint32 AnalyzeFunc uint32 @@ -469,7 +470,7 @@ func (t *DoltgresType) IoInput(ctx *sql.Context, input string) (any, error) { func (t *DoltgresType) IoOutput(ctx *sql.Context, val any) (string, error) { var o any var err error - if t.ModInFunc != 0 || t.IsArrayType() { + if t.ModInFunc != 0 || t.IsArrayType() || t.IsCompositeType() { send := globalFunctionRegistry.GetFunction(t.OutputFunc) resolvedTypes := send.ResolvedTypes() resolvedTypes[0] = t @@ -494,6 +495,12 @@ func (t *DoltgresType) IsArrayType() bool { (t.TypCategory == TypeCategory_PseudoTypes && t.ID.TypeName() == "anyarray") } +// IsCompositeType returns true if the type is a composite type, such as an anonymous record, or a +// user-created composite type. +func (t *DoltgresType) IsCompositeType() bool { + return t.ID.TypeName() == "record" || t.TypType == TypeType_Composite +} + // IsEmptyType returns true if the type is not valid. func (t *DoltgresType) IsEmptyType() bool { return t == nil diff --git a/server/types/utils.go b/server/types/utils.go index 81fb6ee919..ec8ee32ab1 100644 --- a/server/types/utils.go +++ b/server/types/utils.go @@ -16,6 +16,7 @@ package types import ( "bytes" + "fmt" "strings" cerrors "github.com/cockroachdb/errors" @@ -152,20 +153,7 @@ func ArrToString(ctx *sql.Context, arr []any, baseType *DoltgresType, trimBool b if baseType.ID == Bool.ID && trimBool { str = string(str[0]) } - shouldQuote := false - for _, r := range str { - switch r { - case ' ', ',', '{', '}', '\\', '"': - shouldQuote = true - } - } - if shouldQuote || strings.EqualFold(str, "NULL") { - sb.WriteRune('"') - sb.WriteString(strings.ReplaceAll(str, `"`, `\"`)) - sb.WriteRune('"') - } else { - sb.WriteString(str) - } + sb.WriteString(quoteString(str)) } else { sb.WriteString("NULL") } @@ -174,6 +162,52 @@ func ArrToString(ctx *sql.Context, arr []any, baseType *DoltgresType, trimBool b return sb.String(), nil } +// RecordToString is used for the record_out function, to serialize record values for wire transfer. +// |fields| contains the values to serialize. +func RecordToString(ctx *sql.Context, fields []RecordValue) (any, error) { + sb := strings.Builder{} + sb.WriteRune('(') + for i, value := range fields { + if i > 0 { + sb.WriteString(",") + } + + if value.Value == nil { + continue + } + + str, err := value.Type.IoOutput(ctx, value.Value) + if err != nil { + return "", err + } + if value.Type.ID == Bool.ID { + str = string(str[0]) + } + + sb.WriteString(quoteString(str)) + } + sb.WriteRune(')') + + return sb.String(), nil +} + +// quoteString determines if |s| needs to be quoted, by looking for special characters like ' ' or ',', +// and if so, quotes the string and returns it. If quoting is not needed, then |s| is returned as is. +func quoteString(s string) string { + shouldQuote := false + for _, r := range s { + switch r { + case ' ', ',', '{', '}', '\\', '"': + shouldQuote = true + } + } + if shouldQuote || strings.EqualFold(s, "NULL") { + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(s, `"`, `\"`)) + } else { + return s + } +} + // toInternal returns an Internal ID for the given type. This is only used for the built-in types, since they all share // the same schema (pg_catalog). func toInternal(typeName string) id.Type { diff --git a/testing/go/record_test.go b/testing/go/record_test.go new file mode 100644 index 0000000000..e689ee7855 --- /dev/null +++ b/testing/go/record_test.go @@ -0,0 +1,366 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package _go + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestRecords(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "Record cannot be used as column type", + SetUpScript: []string{ + "CREATE TABLE t2 (pk INT PRIMARY KEY, c1 VARCHAR(100));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t (pk INT PRIMARY KEY, r RECORD);", + ExpectedErr: `column "r" has pseudo-type record`, + }, + { + Query: "ALTER TABLE t2 ADD COLUMN c2 RECORD;", + ExpectedErr: `column "c2" has pseudo-type record`, + }, + { + Query: "ALTER TABLE t2 ALTER COLUMN c1 TYPE RECORD;", + ExpectedErr: `column "c1" has pseudo-type record`, + }, + { + Query: "CREATE DOMAIN my_domain AS record;", + ExpectedErr: `"record" is not a valid base type for a domain`, + }, + { + Query: "CREATE SEQUENCE my_seq AS record;", + ExpectedErr: "sequence type must be smallint, integer, or bigint", + }, + { + Query: "CREATE TYPE outer_type AS (id int, payload record);", + ExpectedErr: `column "payload" has pseudo-type record`, + }, + }, + }, + { + Name: "Casting to record", + Assertions: []ScriptTestAssertion{ + { + Query: "select row(1, 1)::record;", + Expected: []sql.Row{{"(1,1)"}}, + }, + }, + }, + { + // TODO: Wrapping table rows with ROW() is not supported yet. Planbuilder assumes the + // table alias is a column name and not a table. + Name: "ROW() wrapping table rows", + SetUpScript: []string{ + "create table users (name text, location text, age int);", + "insert into users values ('jason', 'SEA', 42), ('max', 'SFO', 31);", + }, + Assertions: []ScriptTestAssertion{ + { + // TODO: ERROR: column "p" could not be found in any table in scope + Skip: true, + Query: "select row(p) from users p;", + Expected: []sql.Row{{`("(jason,SEA,44)")`}, {`("(max,SFO,31)")`}}, + }, + { + // TODO: ERROR: name resolution on this statement is not yet supported + Skip: true, + Query: "select row(p.*, 42) from users p;", + Expected: []sql.Row{{`(jason,SEA,42,42)`}, {`(max,SFO,31,42)`}}, + }, + { + // TODO: ERROR: (E).x is not yet supported + Skip: true, + Query: "SELECT (u).location FROM users u;", + Expected: []sql.Row{{"SEA"}, {"SFO"}}, + }, + }, + }, + { + Name: "ROW() wrapping values", + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT ROW(1, 2, 3) as myRow;", + Expected: []sql.Row{{"(1,2,3)"}}, + }, + { + Query: "SELECT (4, 5, 6) as myRow;", + Expected: []sql.Row{{"(4,5,6)"}}, + }, + { + Query: "SELECT (NULL, 'foo', NULL) as myRow;", + Expected: []sql.Row{{"(,foo,)"}}, + }, + { + Query: "SELECT (NULL, (1 > 0), 'baz') as myRow;", + Expected: []sql.Row{{"(,t,baz)"}}, + }, + }, + }, + { + Name: "ROW() equality and comparison", + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT ROW(1, 'x') = ROW(1, 'x');", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, 'x') = ROW(1, 'y');", + Expected: []sql.Row{{"f"}}, + }, + { + Query: "SELECT ROW(1, NULL) = ROW(1, 1);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(1, 2) < ROW(1, 3);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, 2) < ROW(2, NULL);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(2, 2) < ROW(2, NULL);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(2, 2, 1) < ROW(2, NULL, 2);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(1, 2) < ROW(NULL, 3);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(NULL, NULL, NULL) < ROW(NULL, NULL, NULL);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(1, 2) <= ROW(1, 3);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, 2) <= ROW(1, 2);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, NULL) <= ROW(1, 2);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(2, 1) > ROW(1, 999);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(2, 1) > ROW(1, NULL);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(2, 1) >= ROW(1, 999);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(2, 1) >= ROW(2, 1);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(NULL, 1) >= ROW(2, 1);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT ROW(1, 2) != ROW(3, 4);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, 2) != ROW(NULL, 4);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(NULL, 4) != ROW(NULL, 4);", + Expected: []sql.Row{{nil}}, + }, + { + // TODO: IS NOT DISTINCT FROM is not yet supported + Skip: true, + Query: "SELECT ROW(1, NULL) IS NOT DISTINCT FROM ROW(1, NULL);", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(1, '2') = ROW(1, 2::TEXT);", + Expected: []sql.Row{{"t"}}, + }, + }, + }, + { + // TODO: Additional work is needed to support inserting records into tables + Skip: true, + Name: "ROW() use inserting and selecting composite rows", + SetUpScript: []string{ + "CREATE TYPE user_info AS (id INT, name TEXT, email TEXT);", + "CREATE TABLE accounts (info user_info);", + }, + Assertions: []ScriptTestAssertion{ + { + // TODO: ERROR: ASSIGNMENT_CAST: target is of type user_info but expression is of type record + Query: "INSERT INTO accounts VALUES (ROW(1, 'alice', 'a@example.com'));", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "SELECT info FROM accounts;", + Expected: []sql.Row{{"(1,alice,a@example.com)"}}, + }, + { + // TODO: ERROR: (E).x is not yet supported (SQLSTATE XX000) + Query: "SELECT (a.info).name FROM accounts a;", + Expected: []sql.Row{{"alice"}}, + }, + }, + }, + { + Name: "ROW() use in WHERE clause", + SetUpScript: []string{ + "create table users (id int primary key, name text, email text);", + "insert into users values (1, 'John', 'j@a.com'), (2, 'Joe', 'joe@joe.com');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM users WHERE ROW(id, name, email) = ROW(1, 'John', 'j@a.com');", + Expected: []sql.Row{{1, "John", "j@a.com"}}, + }, + { + // TODO: IS NOT DISTINCT FROM is not yet supported + Skip: true, + Query: "SELECT * FROM users WHERE ROW(id, name) IS NOT DISTINCT FROM ROW(2, 'Jane');", + Expected: []sql.Row{{2, "Joe", "joe@joe.com"}}, + }, + }, + }, + { + Name: "ROW() casting and type inference", + Assertions: []ScriptTestAssertion{ + { + // TODO: ERROR: unknown type with oid: 2249 + Skip: true, + Query: "SELECT ROW(1, 'a')::record;", + Expected: []sql.Row{{"(1,a)"}}, + }, + { + // TODO: This does not return an error yet + Skip: true, + Query: "SELECT ROW(1, 2) = ROW(1, 'two');", + ExpectedErr: "invalid input syntax", + }, + { + // TODO: interface conversion panic + Skip: true, + Query: "SELECT ROW(1, 2) = ROW(1, '2');", + Expected: []sql.Row{{"t"}}, + }, + }, + }, + { + Name: "ROW() error cases and edge conditions", + SetUpScript: []string{ + "create table users (id int primary key, name text, email text);", + "insert into users values (1, 'John', 'j@a.com'), (2, 'Joe', 'joe@joe.com');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT ROW(1, 2) = ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) = ROW(1, 2, 3);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) < ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) <= ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) > ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) >= ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + Query: "SELECT ROW(1, 2) != ROW(1);", + ExpectedErr: "unequal number of entries", + }, + { + // TODO: expression.IsNull in GMS is used in this evaluation, but returns + // false for this case, because the record evaluates to []any{nil} + // instead of just nil. + Skip: true, + Query: "SELECT ROW(NULL) IS NULL", + Expected: []sql.Row{{"t"}}, + }, + { + // TODO: expression.IsNull in GMS is used in this evaluation, but returns + // false for this case, because the record evaluates to []any{nil} + // instead of just nil. + Skip: true, + Query: "SELECT ROW(NULL, NULL, NULL) IS NULL;", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(NULL, 42, NULL) IS NULL;", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT ROW(42) IS NULL", + Expected: []sql.Row{{0}}, + }, + { + // TODO: expression.IsNull in GMS is used in this evaluation (wrapped with + // an expression.Not), but returns true for this case, because the record + // evaluates to []any{nil} instead of just nil. + Skip: true, + Query: "SELECT ROW(NULL) IS NOT NULL;", + Expected: []sql.Row{{"f"}}, + }, + { + Query: "SELECT ROW(42) IS NOT NULL;", + Expected: []sql.Row{{"t"}}, + }, + { + Query: "SELECT ROW(id, name), COUNT(*) FROM users GROUP BY ROW(id, name);", + Expected: []sql.Row{{"(1,John)", 1}, {"(2,Joe)", 1}}, + }, + }, + }, + { + Name: "ROW() nesting", + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT ROW(ROW(1, 'x'), true);", + Expected: []sql.Row{{`("(1,x)",t)`}}, + }, + }, + }, + }) +}