Skip to content

Commit 252cebd

Browse files
committed
Extending the record type to support basic ROW() constructor uses
1 parent 17981fd commit 252cebd

File tree

20 files changed

+851
-56
lines changed

20 files changed

+851
-56
lines changed

server/ast/alter_table.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/sirupsen/logrus"
2323

2424
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
25+
pgtypes "github.com/dolthub/doltgresql/server/types"
2526
)
2627

2728
// nodeAlterTable handles *tree.AlterTable nodes.
@@ -318,6 +319,10 @@ func nodeAlterTableAlterColumnType(ctx *Context, node *tree.AlterTableAlterColum
318319
return nil, err
319320
}
320321

322+
if resolvedType == pgtypes.Record {
323+
return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Column.String())
324+
}
325+
321326
return &vitess.DDL{
322327
Action: "alter",
323328
Table: tableName,

server/ast/column_table_def.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
4242
return nil, err
4343
}
4444

45+
if resolvedType == pgtypes.Record {
46+
return nil, errors.Errorf(`column "%s" has pseudo-type record`, node.Name.String())
47+
}
48+
4549
var isNull vitess.BoolVal
4650
var isNotNull vitess.BoolVal
4751
switch node.Nullable.Nullability {

server/ast/create_domain.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
2424
pgnodes "github.com/dolthub/doltgresql/server/node"
25+
pgtypes "github.com/dolthub/doltgresql/server/types"
2526
)
2627

2728
// nodeCreateDomain handles *tree.CreateDomain nodes.
@@ -37,6 +38,11 @@ func nodeCreateDomain(ctx *Context, node *tree.CreateDomain) (vitess.Statement,
3738
if err != nil {
3839
return nil, err
3940
}
41+
42+
if dataType == pgtypes.Record {
43+
return nil, errors.Errorf(`"record" is not a valid base type for a domain`)
44+
}
45+
4046
if node.Collate != "" {
4147
return nil, errors.Errorf("domain collation is not yet supported")
4248
}

server/ast/create_type.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
2323
pgnodes "github.com/dolthub/doltgresql/server/node"
24+
pgtypes "github.com/dolthub/doltgresql/server/types"
2425
)
2526

2627
// nodeCreateType handles *tree.CreateType nodes.
@@ -43,6 +44,11 @@ func nodeCreateType(ctx *Context, node *tree.CreateType) (vitess.Statement, erro
4344
if err != nil {
4445
return nil, err
4546
}
47+
48+
if dataType == pgtypes.Record {
49+
return nil, errors.Errorf(`column "%s" has pseudo-type record`, t.AttrName)
50+
}
51+
4652
typs[i] = pgnodes.CompositeAsType{
4753
AttrName: t.AttrName,
4854
Typ: dataType,

server/ast/expr.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,15 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
373373
Expression: pgexprs.NewInSubquery(),
374374
Children: vitess.Exprs{left, right},
375375
}, nil
376-
default:
377-
return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right)
376+
case vitess.InjectedExpr:
377+
if _, ok := right.Expression.(*pgexprs.RecordExpr); ok {
378+
return vitess.InjectedExpr{
379+
Expression: pgexprs.NewInTuple(),
380+
Children: vitess.Exprs{left, vitess.ValTuple(right.Children)},
381+
}, nil
382+
}
378383
}
384+
return nil, errors.Errorf("right side of IN expression is not a tuple or subquery, got %T", right)
379385
case tree.NotIn:
380386
innerExpr := vitess.InjectedExpr{
381387
Expression: pgexprs.NewInTuple(),
@@ -776,15 +782,16 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
776782
if len(node.Labels) > 0 {
777783
return nil, errors.Errorf("tuple labels are not yet supported")
778784
}
779-
if node.Row {
780-
return nil, errors.Errorf("ROW keyword for tuples not yet supported")
781-
}
782785

783786
valTuple, err := nodeExprs(ctx, node.Exprs)
784787
if err != nil {
785788
return nil, err
786789
}
787-
return vitess.ValTuple(valTuple), nil
790+
791+
return vitess.InjectedExpr{
792+
Expression: pgexprs.NewRecordExpr(),
793+
Children: valTuple,
794+
}, nil
788795
case *tree.TupleStar:
789796
return nil, errors.Errorf("(E).* is not yet supported")
790797
case *tree.UnaryExpr:

server/ast/resolvable_type_reference.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
7272
return nil, nil, errors.Errorf("geography types are not yet supported")
7373
} else {
7474
switch columnType.Oid() {
75+
case oid.T_record:
76+
resolvedType = pgtypes.Record
7577
case oid.T_bool:
7678
resolvedType = pgtypes.Bool
7779
case oid.T_bytea:

server/compare/utils.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package compare
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
22+
23+
"github.com/dolthub/doltgresql/server/functions/framework"
24+
pgtypes "github.com/dolthub/doltgresql/server/types"
25+
)
26+
27+
// CompareRecords compares two record values using the specified comparison operator, |op| and returns a result
28+
// indicating if the comparison was true, false, or indeterminate (nil).
29+
//
30+
// More info on rules for comparing records:
31+
// https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON
32+
func CompareRecords(ctx *sql.Context, op framework.Operator, v1 interface{}, v2 interface{}) (result any, err error) {
33+
leftRecord, rightRecord, err := checkRecordArgs(v1, v2)
34+
if err != nil {
35+
return nil, err
36+
}
37+
38+
hasNull := false
39+
hasEqualFields := true
40+
41+
for i := 0; i < len(leftRecord); i++ {
42+
typ1 := leftRecord[i].Type
43+
typ2 := rightRecord[i].Type
44+
45+
// NULL values are by definition not comparable, so they need special handling depending
46+
// on what type of comparison we're performing.
47+
if leftRecord[i].Value == nil || rightRecord[i].Value == nil {
48+
switch op {
49+
case framework.Operator_BinaryEqual:
50+
// If we're comparing for equality, then any presence of a NULL value means
51+
// we don't have enough information to determine equality, so return nil.
52+
return nil, nil
53+
54+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan,
55+
framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryGreaterOrEqual:
56+
// If we haven't seen a prior field with non-equivalent values, then we
57+
// don't have enough certainty to make a comparison, so return nil.
58+
if hasEqualFields {
59+
return nil, nil
60+
}
61+
}
62+
63+
// Otherwise, mark that we've seen a NULL and skip over it
64+
hasNull = true
65+
continue
66+
}
67+
68+
leftLiteral := expression.NewLiteral(leftRecord[i].Value, typ1)
69+
rightLiteral := expression.NewLiteral(rightRecord[i].Value, typ2)
70+
71+
// For >= and <=, we need to distinguish between < and = (and > and =). Records
72+
// are compared by evaluating each field, in order of significance, so for >= if
73+
// the field is greater than, then we can stop comparing and return true immediately.
74+
// If the field is equal, then we need to look at the next field. For this reason,
75+
// we have to break >= and <= into separate comparisons for > or < and =.
76+
switch op {
77+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryLessOrEqual:
78+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryLessThan, leftLiteral, rightLiteral); err != nil {
79+
return false, err
80+
} else if res == true {
81+
return true, nil
82+
}
83+
case framework.Operator_BinaryGreaterThan, framework.Operator_BinaryGreaterOrEqual:
84+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryGreaterThan, leftLiteral, rightLiteral); err != nil {
85+
return false, err
86+
} else if res == true {
87+
return true, nil
88+
}
89+
}
90+
91+
// After we've determined > and <, we can look at the equality comparison. For < and >, we've already returned
92+
// true if that initial comparison was true. Now we need to determine if the two fields are equal, in which case
93+
// we continue on to check the next field. If the two fields are NOT equal, then we can return false immediately.
94+
switch op {
95+
case framework.Operator_BinaryGreaterOrEqual, framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryEqual:
96+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
97+
return false, err
98+
} else if res == false {
99+
return false, nil
100+
}
101+
case framework.Operator_BinaryNotEqual:
102+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryNotEqual, leftLiteral, rightLiteral); err != nil {
103+
return false, err
104+
} else if res == true {
105+
return true, nil
106+
}
107+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan:
108+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
109+
return false, err
110+
} else if res == false {
111+
// For < and >, we still need to check equality to know if we need to continue checking additional fields.
112+
// If
113+
hasEqualFields = false
114+
}
115+
default:
116+
return false, fmt.Errorf("unsupportd binary operator: %s", op)
117+
}
118+
}
119+
120+
// If the records contain any NULL fields, but all non-NULL fields are equal, then we
121+
// don't have enough certainty to return a result.
122+
if hasNull && hasEqualFields {
123+
return nil, nil
124+
}
125+
126+
return true, nil
127+
}
128+
129+
// checkRecordArgs asserts that |v1| and |v2| are both []pgtypes.RecordValue, and that they have the same number of
130+
// elements, then returns them. If any problems were detected, an error is returnd instead.
131+
func checkRecordArgs(v1, v2 interface{}) (leftRecord, rightRecord []pgtypes.RecordValue, err error) {
132+
leftRecord, ok1 := v1.([]pgtypes.RecordValue)
133+
rightRecord, ok2 := v2.([]pgtypes.RecordValue)
134+
if !ok1 {
135+
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v1)
136+
} else if !ok2 {
137+
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v2)
138+
}
139+
140+
if len(leftRecord) != len(rightRecord) {
141+
return nil, nil, fmt.Errorf("unequal number of entries in row expressions")
142+
}
143+
144+
return leftRecord, rightRecord, nil
145+
}
146+
147+
// callComparisonFunction invokes the binary comparison function for the specified operator |op| with the two arguments
148+
// |leftLiteral| and |rightLiteral|. The result and any error are returned.
149+
func callComparisonFunction(ctx *sql.Context, op framework.Operator, leftLiteral, rightLiteral sql.Expression) (result any, err error) {
150+
intermediateFunction := framework.GetBinaryFunction(op)
151+
compiledFunction := intermediateFunction.Compile(
152+
"_internal_record_comparison_function", leftLiteral, rightLiteral)
153+
return compiledFunction.Eval(ctx, nil)
154+
}

server/expression/in_tuple.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,18 @@ func (it *InTuple) WithResolvedChildren(children []any) (any, error) {
242242
if !ok {
243243
return nil, errors.Errorf("expected vitess child to be an expression but has type `%T`", children[0])
244244
}
245-
right, ok := children[1].(expression.Tuple)
246-
if !ok {
247-
return nil, errors.Errorf("expected vitess child to be an expression tuple but has type `%T`", children[1])
245+
246+
switch right := children[1].(type) {
247+
case expression.Tuple:
248+
return it.WithChildren(left, right)
249+
case *RecordExpr:
250+
// TODO: For now, if we see a RecordExpr come in, we convert it to a vitess Tuple representation, so that
251+
// the existing in_tuple code can work with it. Alternatively, we could change in_tuple to always
252+
// work directly with a Record expression.
253+
return it.WithChildren(left, expression.Tuple(right.exprs))
254+
default:
255+
return nil, errors.Errorf("expected child to be a RecordExpr or vitess Tuple but has type `%T`", children[1])
248256
}
249-
return it.WithChildren(left, right)
250257
}
251258

252259
// Left implements the expression.BinaryExpression interface.

0 commit comments

Comments
 (0)