Skip to content

Commit 6b7c703

Browse files
committed
Refactoring record comparison to use binary operator functions, instead of generic compare logic
1 parent 413391f commit 6b7c703

File tree

9 files changed

+171
-219
lines changed

9 files changed

+171
-219
lines changed

server/compare/utils.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package compare
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
22+
23+
"github.com/dolthub/doltgresql/server/functions/framework"
24+
pgtypes "github.com/dolthub/doltgresql/server/types"
25+
)
26+
27+
// CompareRecords compares two record values using the specified comparison operator, |op| and returns a result
28+
// indicating if the comparison was true, false, or indeterminate (nil).
29+
func CompareRecords(ctx *sql.Context, op framework.Operator, v1 interface{}, v2 interface{}) (result any, err error) {
30+
leftRecord, rightRecord, err := checkRecordArgs(v1, v2)
31+
if err != nil {
32+
return nil, err
33+
}
34+
35+
hasNull := false
36+
hasEqualFields := true
37+
38+
for i := 0; i < len(leftRecord); i++ {
39+
typ1, ok1 := leftRecord[i].Type.(*pgtypes.DoltgresType)
40+
typ2, ok2 := rightRecord[i].Type.(*pgtypes.DoltgresType)
41+
if !ok1 {
42+
return false, fmt.Errorf("expected a DoltgresType, but got %T", leftRecord[i].Type)
43+
} else if !ok2 {
44+
return false, fmt.Errorf("expected a DoltgresType, but got %T", rightRecord[i].Type)
45+
}
46+
47+
// NULL values are by definition not comparable, so they need special handling depending
48+
// on what type of comparison we're performing.
49+
if leftRecord[i].Value == nil || rightRecord[i].Value == nil {
50+
switch op {
51+
case framework.Operator_BinaryEqual:
52+
// If we're comparing for equality, then any presence of a NULL value means
53+
// we don't have enough information to determine equality, so return nil.
54+
return nil, nil
55+
56+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan,
57+
framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryGreaterOrEqual:
58+
// If we haven't seen a prior field with non-equivalent values, then we
59+
// don't have enough certainty to make a comparison, so return nil.
60+
if hasEqualFields {
61+
return nil, nil
62+
}
63+
}
64+
65+
// Otherwise, mark that we've seen a NULL and skip over it
66+
hasNull = true
67+
continue
68+
}
69+
70+
leftLiteral := expression.NewLiteral(leftRecord[i].Value, typ1)
71+
rightLiteral := expression.NewLiteral(rightRecord[i].Value, typ2)
72+
73+
// For >= and <=, we need to distinguish between < and = (and > and =). Records
74+
// are compared by evaluating each field, in order of significance, so for >= if
75+
// the field is greater than, then we can stop comparing and return true immediately.
76+
// If the field is equal, then we need to look at the next field. For this reason,
77+
// we have to break >= and <= into separate comparisons for > or < and =.
78+
switch op {
79+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryLessOrEqual:
80+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryLessThan, leftLiteral, rightLiteral); err != nil {
81+
return false, err
82+
} else if res == true {
83+
return true, nil
84+
}
85+
case framework.Operator_BinaryGreaterThan, framework.Operator_BinaryGreaterOrEqual:
86+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryGreaterThan, leftLiteral, rightLiteral); err != nil {
87+
return false, err
88+
} else if res == true {
89+
return true, nil
90+
}
91+
}
92+
93+
// After we've determined > and <, we can look at the equality comparison. For < and >, we've already returned
94+
// true if that initial comparison was true. Now we need to determine if the two fields are equal, in which case
95+
// we continue on to check the next field. If the two fields are NOT equal, then we can return false immediately.
96+
switch op {
97+
case framework.Operator_BinaryGreaterOrEqual, framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryEqual:
98+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
99+
return false, err
100+
} else if res == false {
101+
return false, nil
102+
}
103+
case framework.Operator_BinaryNotEqual:
104+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryNotEqual, leftLiteral, rightLiteral); err != nil {
105+
return false, err
106+
} else if res == true {
107+
return true, nil
108+
}
109+
case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan:
110+
if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil {
111+
return false, err
112+
} else if res == false {
113+
// For < and >, we still need to check equality to know if we need to continue checking additional fields.
114+
// If
115+
hasEqualFields = false
116+
}
117+
default:
118+
return false, fmt.Errorf("unsupportd binary operator: %s", op)
119+
}
120+
}
121+
122+
// If the records contain any NULL fields, but all non-NULL fields are equal, then we
123+
// don't have enough certainty to return a result.
124+
if hasNull && hasEqualFields {
125+
return nil, nil
126+
}
127+
128+
return true, nil
129+
}
130+
131+
// checkRecordArgs asserts that |v1| and |v2| are both []pgtypes.RecordValue, and that they have the same number of
132+
// elements, then returns them. If any problems were detected, an error is returnd instead.
133+
func checkRecordArgs(v1, v2 interface{}) (leftRecord, rightRecord []pgtypes.RecordValue, err error) {
134+
leftRecord, ok1 := v1.([]pgtypes.RecordValue)
135+
rightRecord, ok2 := v2.([]pgtypes.RecordValue)
136+
if !ok1 {
137+
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v1)
138+
} else if !ok2 {
139+
return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v2)
140+
}
141+
142+
if len(leftRecord) != len(rightRecord) {
143+
return nil, nil, fmt.Errorf("unequal number of entries in row expressions")
144+
}
145+
146+
return leftRecord, rightRecord, nil
147+
}
148+
149+
// callComparisonFunction invokes the binary comparison function for the specified operator |op| with the two arguments
150+
// |leftLiteral| and |rightLiteral|. The result and any error are returned.
151+
func callComparisonFunction(ctx *sql.Context, op framework.Operator, leftLiteral, rightLiteral sql.Expression) (result any, err error) {
152+
intermediateFunction := framework.GetBinaryFunction(op)
153+
compiledFunction := intermediateFunction.Compile("record-cmp", leftLiteral, rightLiteral)
154+
return compiledFunction.Eval(ctx, nil)
155+
}

server/functions/binary/equal.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/dolthub/doltgresql/core/id"
2424
"github.com/dolthub/doltgresql/postgres/parser/duration"
2525
"github.com/dolthub/doltgresql/postgres/parser/uuid"
26+
"github.com/dolthub/doltgresql/server/compare"
2627
"github.com/dolthub/doltgresql/server/functions/framework"
2728
pgtypes "github.com/dolthub/doltgresql/server/types"
2829
)
@@ -431,15 +432,7 @@ var record_eq = framework.Function2{
431432
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
432433
Strict: true,
433434
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
434-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
435-
return nil, err
436-
}
437-
// tuples can only be compared for equality if there are no NULL values
438-
if pgtypes.RecordValueHasNull(val1) || pgtypes.RecordValueHasNull(val2) {
439-
return nil, nil
440-
}
441-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
442-
return res == 0, err
435+
return compare.CompareRecords(ctx, framework.Operator_BinaryEqual, val1, val2)
443436
},
444437
}
445438

server/functions/binary/greater.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/dolthub/doltgresql/core/id"
2525
"github.com/dolthub/doltgresql/postgres/parser/duration"
2626
"github.com/dolthub/doltgresql/postgres/parser/uuid"
27+
"github.com/dolthub/doltgresql/server/compare"
2728
"github.com/dolthub/doltgresql/server/functions/framework"
2829
pgtypes "github.com/dolthub/doltgresql/server/types"
2930
)
@@ -525,14 +526,7 @@ var record_gt = framework.Function2{
525526
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
526527
Strict: true,
527528
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
528-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
529-
return nil, err
530-
}
531-
if !pgtypes.CanCompareRecordValues(val1, val2) {
532-
return nil, nil
533-
}
534-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
535-
return res == 1, err
529+
return compare.CompareRecords(ctx, framework.Operator_BinaryGreaterThan, val1, val2)
536530
},
537531
}
538532

server/functions/binary/greater_equal.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/dolthub/doltgresql/core/id"
2525
"github.com/dolthub/doltgresql/postgres/parser/duration"
2626
"github.com/dolthub/doltgresql/postgres/parser/uuid"
27+
"github.com/dolthub/doltgresql/server/compare"
2728
"github.com/dolthub/doltgresql/server/functions/framework"
2829
pgtypes "github.com/dolthub/doltgresql/server/types"
2930
)
@@ -525,14 +526,7 @@ var record_ge = framework.Function2{
525526
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
526527
Strict: true,
527528
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
528-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
529-
return nil, err
530-
}
531-
if !pgtypes.CanCompareRecordValues(val1, val2) {
532-
return nil, nil
533-
}
534-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
535-
return res >= 0, err
529+
return compare.CompareRecords(ctx, framework.Operator_BinaryGreaterOrEqual, val1, val2)
536530
},
537531
}
538532

server/functions/binary/less.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/dolthub/doltgresql/core/id"
2525
"github.com/dolthub/doltgresql/postgres/parser/duration"
2626
"github.com/dolthub/doltgresql/postgres/parser/uuid"
27+
"github.com/dolthub/doltgresql/server/compare"
2728
"github.com/dolthub/doltgresql/server/functions/framework"
2829
pgtypes "github.com/dolthub/doltgresql/server/types"
2930
)
@@ -525,14 +526,7 @@ var record_lt = framework.Function2{
525526
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
526527
Strict: true,
527528
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
528-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
529-
return nil, err
530-
}
531-
if !pgtypes.CanCompareRecordValues(val1, val2) {
532-
return nil, nil
533-
}
534-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
535-
return res == -1, err
529+
return compare.CompareRecords(ctx, framework.Operator_BinaryLessThan, val1, val2)
536530
},
537531
}
538532

server/functions/binary/less_equal.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/dolthub/doltgresql/core/id"
2525
"github.com/dolthub/doltgresql/postgres/parser/duration"
2626
"github.com/dolthub/doltgresql/postgres/parser/uuid"
27+
"github.com/dolthub/doltgresql/server/compare"
2728
"github.com/dolthub/doltgresql/server/functions/framework"
2829
pgtypes "github.com/dolthub/doltgresql/server/types"
2930
)
@@ -525,14 +526,7 @@ var record_le = framework.Function2{
525526
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
526527
Strict: true,
527528
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
528-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
529-
return nil, err
530-
}
531-
if !pgtypes.CanCompareRecordValues(val1, val2) {
532-
return nil, nil
533-
}
534-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
535-
return res <= 0, err
529+
return compare.CompareRecords(ctx, framework.Operator_BinaryLessThan, val1, val2)
536530
},
537531
}
538532

server/functions/binary/not_equal.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/dolthub/doltgresql/core/id"
2424
"github.com/dolthub/doltgresql/postgres/parser/duration"
2525
"github.com/dolthub/doltgresql/postgres/parser/uuid"
26+
"github.com/dolthub/doltgresql/server/compare"
2627
"github.com/dolthub/doltgresql/server/functions/framework"
2728
pgtypes "github.com/dolthub/doltgresql/server/types"
2829
)
@@ -526,14 +527,7 @@ var record_ne = framework.Function2{
526527
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Record, pgtypes.Record},
527528
Strict: true,
528529
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
529-
if err := pgtypes.ValidateEqualRecordFieldCount(val1, val2); err != nil {
530-
return nil, err
531-
}
532-
if !pgtypes.CanCompareRecordValuesForNotEquals(val1, val2) {
533-
return nil, nil
534-
}
535-
res, err := pgtypes.CompareRecords(ctx, val1, val2)
536-
return res != 0, err
530+
return compare.CompareRecords(ctx, framework.Operator_BinaryNotEqual, val1, val2)
537531
},
538532
}
539533

0 commit comments

Comments
 (0)