|
| 1 | +// Copyright 2025 Dolthub, Inc. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package compare |
| 16 | + |
| 17 | +import ( |
| 18 | + "fmt" |
| 19 | + |
| 20 | + "github.com/dolthub/go-mysql-server/sql" |
| 21 | + "github.com/dolthub/go-mysql-server/sql/expression" |
| 22 | + |
| 23 | + "github.com/dolthub/doltgresql/server/functions/framework" |
| 24 | + pgtypes "github.com/dolthub/doltgresql/server/types" |
| 25 | +) |
| 26 | + |
| 27 | +// CompareRecords compares two record values using the specified comparison operator, |op| and returns a result |
| 28 | +// indicating if the comparison was true, false, or indeterminate (nil). |
| 29 | +// |
| 30 | +// More info on rules for comparing records: |
| 31 | +// https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON |
| 32 | +func CompareRecords(ctx *sql.Context, op framework.Operator, v1 interface{}, v2 interface{}) (result any, err error) { |
| 33 | + leftRecord, rightRecord, err := checkRecordArgs(v1, v2) |
| 34 | + if err != nil { |
| 35 | + return nil, err |
| 36 | + } |
| 37 | + |
| 38 | + hasNull := false |
| 39 | + hasEqualFields := true |
| 40 | + |
| 41 | + for i := 0; i < len(leftRecord); i++ { |
| 42 | + typ1 := leftRecord[i].Type |
| 43 | + typ2 := rightRecord[i].Type |
| 44 | + |
| 45 | + // NULL values are by definition not comparable, so they need special handling depending |
| 46 | + // on what type of comparison we're performing. |
| 47 | + if leftRecord[i].Value == nil || rightRecord[i].Value == nil { |
| 48 | + switch op { |
| 49 | + case framework.Operator_BinaryEqual: |
| 50 | + // If we're comparing for equality, then any presence of a NULL value means |
| 51 | + // we don't have enough information to determine equality, so return nil. |
| 52 | + return nil, nil |
| 53 | + |
| 54 | + case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan, |
| 55 | + framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryGreaterOrEqual: |
| 56 | + // If we haven't seen a prior field with non-equivalent values, then we |
| 57 | + // don't have enough certainty to make a comparison, so return nil. |
| 58 | + if hasEqualFields { |
| 59 | + return nil, nil |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + // Otherwise, mark that we've seen a NULL and skip over it |
| 64 | + hasNull = true |
| 65 | + continue |
| 66 | + } |
| 67 | + |
| 68 | + leftLiteral := expression.NewLiteral(leftRecord[i].Value, typ1) |
| 69 | + rightLiteral := expression.NewLiteral(rightRecord[i].Value, typ2) |
| 70 | + |
| 71 | + // For >= and <=, we need to distinguish between < and = (and > and =). Records |
| 72 | + // are compared by evaluating each field, in order of significance, so for >= if |
| 73 | + // the field is greater than, then we can stop comparing and return true immediately. |
| 74 | + // If the field is equal, then we need to look at the next field. For this reason, |
| 75 | + // we have to break >= and <= into separate comparisons for > or < and =. |
| 76 | + switch op { |
| 77 | + case framework.Operator_BinaryLessThan, framework.Operator_BinaryLessOrEqual: |
| 78 | + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryLessThan, leftLiteral, rightLiteral); err != nil { |
| 79 | + return false, err |
| 80 | + } else if res == true { |
| 81 | + return true, nil |
| 82 | + } |
| 83 | + case framework.Operator_BinaryGreaterThan, framework.Operator_BinaryGreaterOrEqual: |
| 84 | + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryGreaterThan, leftLiteral, rightLiteral); err != nil { |
| 85 | + return false, err |
| 86 | + } else if res == true { |
| 87 | + return true, nil |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + // After we've determined > and <, we can look at the equality comparison. For < and >, we've already returned |
| 92 | + // true if that initial comparison was true. Now we need to determine if the two fields are equal, in which case |
| 93 | + // we continue on to check the next field. If the two fields are NOT equal, then we can return false immediately. |
| 94 | + switch op { |
| 95 | + case framework.Operator_BinaryGreaterOrEqual, framework.Operator_BinaryLessOrEqual, framework.Operator_BinaryEqual: |
| 96 | + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil { |
| 97 | + return false, err |
| 98 | + } else if res == false { |
| 99 | + return false, nil |
| 100 | + } |
| 101 | + case framework.Operator_BinaryNotEqual: |
| 102 | + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryNotEqual, leftLiteral, rightLiteral); err != nil { |
| 103 | + return false, err |
| 104 | + } else if res == true { |
| 105 | + return true, nil |
| 106 | + } |
| 107 | + case framework.Operator_BinaryLessThan, framework.Operator_BinaryGreaterThan: |
| 108 | + if res, err := callComparisonFunction(ctx, framework.Operator_BinaryEqual, leftLiteral, rightLiteral); err != nil { |
| 109 | + return false, err |
| 110 | + } else if res == false { |
| 111 | + // For < and >, we still need to check equality to know if we need to continue checking additional fields. |
| 112 | + // If |
| 113 | + hasEqualFields = false |
| 114 | + } |
| 115 | + default: |
| 116 | + return false, fmt.Errorf("unsupportd binary operator: %s", op) |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + // If the records contain any NULL fields, but all non-NULL fields are equal, then we |
| 121 | + // don't have enough certainty to return a result. |
| 122 | + if hasNull && hasEqualFields { |
| 123 | + return nil, nil |
| 124 | + } |
| 125 | + |
| 126 | + return true, nil |
| 127 | +} |
| 128 | + |
| 129 | +// checkRecordArgs asserts that |v1| and |v2| are both []pgtypes.RecordValue, and that they have the same number of |
| 130 | +// elements, then returns them. If any problems were detected, an error is returnd instead. |
| 131 | +func checkRecordArgs(v1, v2 interface{}) (leftRecord, rightRecord []pgtypes.RecordValue, err error) { |
| 132 | + leftRecord, ok1 := v1.([]pgtypes.RecordValue) |
| 133 | + rightRecord, ok2 := v2.([]pgtypes.RecordValue) |
| 134 | + if !ok1 { |
| 135 | + return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v1) |
| 136 | + } else if !ok2 { |
| 137 | + return nil, nil, fmt.Errorf("expected a RecordValue, but got %T", v2) |
| 138 | + } |
| 139 | + |
| 140 | + if len(leftRecord) != len(rightRecord) { |
| 141 | + return nil, nil, fmt.Errorf("unequal number of entries in row expressions") |
| 142 | + } |
| 143 | + |
| 144 | + return leftRecord, rightRecord, nil |
| 145 | +} |
| 146 | + |
| 147 | +// callComparisonFunction invokes the binary comparison function for the specified operator |op| with the two arguments |
| 148 | +// |leftLiteral| and |rightLiteral|. The result and any error are returned. |
| 149 | +func callComparisonFunction(ctx *sql.Context, op framework.Operator, leftLiteral, rightLiteral sql.Expression) (result any, err error) { |
| 150 | + intermediateFunction := framework.GetBinaryFunction(op) |
| 151 | + compiledFunction := intermediateFunction.Compile( |
| 152 | + "_internal_record_comparison_function", leftLiteral, rightLiteral) |
| 153 | + return compiledFunction.Eval(ctx, nil) |
| 154 | +} |
0 commit comments