Skip to content

Commit e2a7f7d

Browse files
pramodsatyameta-codesync[bot]
authored andcommitted
feat: Add variant comparison with null handling mode (#15726)
Summary: Enables comparison of variants, with `Variant::equals` API, to use custom null handling mode. Required for constant expression comparison using `null-as-indeterminate` null comparison semantics (#15705). Pull Request resolved: #15726 Reviewed By: pratikpugalia Differential Revision: D95303918 Pulled By: peterenescu fbshipit-source-id: abfda1110116209f85517e2781120768b85c0f38
1 parent d807c6e commit e2a7f7d

File tree

3 files changed

+266
-59
lines changed

3 files changed

+266
-59
lines changed

velox/type/Variant.cpp

Lines changed: 98 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@
2525
namespace facebook::velox {
2626
namespace {
2727

28-
bool dispatchDynamicVariantEquality(
28+
std::optional<bool> dispatchDynamicVariantEquality(
2929
const Variant& a,
3030
const Variant& b,
31-
const bool& enableNullEqualsNull);
31+
CompareFlags::NullHandlingMode nullHandlingMode);
3232

33-
template <bool nullEqualsNull>
34-
bool evaluateNullEquality(const Variant& a, const Variant& b) {
35-
if constexpr (nullEqualsNull) {
36-
if (a.isNull() && b.isNull()) {
37-
return true;
38-
}
33+
std::optional<bool> evaluateNullEquality(
34+
const Variant& a,
35+
const Variant& b,
36+
CompareFlags::NullHandlingMode nullHandlingMode) {
37+
if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue) {
38+
return a.isNull() && b.isNull();
3939
}
40-
return false;
40+
return std::nullopt;
4141
}
4242

4343
template <TypeKind KIND>
@@ -46,10 +46,10 @@ struct VariantEquality;
4646
// scalars
4747
template <TypeKind KIND>
4848
struct VariantEquality {
49-
template <bool NullEqualsNull>
50-
static bool equals(const Variant& a, const Variant& b) {
49+
template <CompareFlags::NullHandlingMode nullHandlingMode>
50+
static std::optional<bool> equals(const Variant& a, const Variant& b) {
5151
if (a.isNull() || b.isNull()) {
52-
return evaluateNullEquality<NullEqualsNull>(a, b);
52+
return evaluateNullEquality(a, b, nullHandlingMode);
5353
}
5454
return a.value<KIND>() == b.value<KIND>();
5555
}
@@ -58,10 +58,10 @@ struct VariantEquality {
5858
// timestamp
5959
template <>
6060
struct VariantEquality<TypeKind::TIMESTAMP> {
61-
template <bool NullEqualsNull>
62-
static bool equals(const Variant& a, const Variant& b) {
61+
template <CompareFlags::NullHandlingMode nullHandlingMode>
62+
static std::optional<bool> equals(const Variant& a, const Variant& b) {
6363
if (a.isNull() || b.isNull()) {
64-
return evaluateNullEquality<NullEqualsNull>(a, b);
64+
return evaluateNullEquality(a, b, nullHandlingMode);
6565
} else {
6666
return a.value<TypeKind::TIMESTAMP>() == b.value<TypeKind::TIMESTAMP>();
6767
}
@@ -71,34 +71,42 @@ struct VariantEquality<TypeKind::TIMESTAMP> {
7171
// array
7272
template <>
7373
struct VariantEquality<TypeKind::ARRAY> {
74-
template <bool NullEqualsNull>
75-
static bool equals(const Variant& a, const Variant& b) {
74+
template <CompareFlags::NullHandlingMode nullHandlingMode>
75+
static std::optional<bool> equals(const Variant& a, const Variant& b) {
7676
if (a.isNull() || b.isNull()) {
77-
return evaluateNullEquality<NullEqualsNull>(a, b);
77+
return evaluateNullEquality(a, b, nullHandlingMode);
7878
}
7979
auto& aArray = a.value<TypeKind::ARRAY>();
8080
auto& bArray = b.value<TypeKind::ARRAY>();
8181
if (aArray.size() != bArray.size()) {
8282
return false;
8383
}
84+
bool isComparisonIndeterminate = false;
8485
for (size_t i = 0; i != aArray.size(); ++i) {
8586
// todo(youknowjack): switch outside the loop
86-
bool result =
87-
dispatchDynamicVariantEquality(aArray[i], bArray[i], NullEqualsNull);
88-
if (!result) {
89-
return false;
87+
auto compareResult = dispatchDynamicVariantEquality(
88+
aArray[i], bArray[i], nullHandlingMode);
89+
if (compareResult.has_value()) {
90+
if (!compareResult.value()) {
91+
return false;
92+
}
93+
} else {
94+
isComparisonIndeterminate = true;
9095
}
9196
}
97+
if (isComparisonIndeterminate) {
98+
return std::nullopt;
99+
}
92100
return true;
93101
}
94102
};
95103

96104
template <>
97105
struct VariantEquality<TypeKind::ROW> {
98-
template <bool NullEqualsNull>
99-
static bool equals(const Variant& a, const Variant& b) {
106+
template <CompareFlags::NullHandlingMode nullHandlingMode>
107+
static std::optional<bool> equals(const Variant& a, const Variant& b) {
100108
if (a.isNull() || b.isNull()) {
101-
return evaluateNullEquality<NullEqualsNull>(a, b);
109+
return evaluateNullEquality(a, b, nullHandlingMode);
102110
}
103111
auto& aRow = a.value<TypeKind::ROW>();
104112
auto& bRow = b.value<TypeKind::ROW>();
@@ -108,23 +116,31 @@ struct VariantEquality<TypeKind::ROW> {
108116
return false;
109117
}
110118
// compare array values
119+
bool isComparisonIndeterminate = false;
111120
for (size_t i = 0; i != aRow.size(); ++i) {
112-
bool result =
113-
dispatchDynamicVariantEquality(aRow[i], bRow[i], NullEqualsNull);
114-
if (!result) {
115-
return false;
121+
auto compareResult =
122+
dispatchDynamicVariantEquality(aRow[i], bRow[i], nullHandlingMode);
123+
if (compareResult.has_value()) {
124+
if (!compareResult.value()) {
125+
return false;
126+
}
127+
} else {
128+
isComparisonIndeterminate = true;
116129
}
117130
}
131+
if (isComparisonIndeterminate) {
132+
return std::nullopt;
133+
}
118134
return true;
119135
}
120136
};
121137

122138
template <>
123139
struct VariantEquality<TypeKind::MAP> {
124-
template <bool NullEqualsNull>
125-
static bool equals(const Variant& a, const Variant& b) {
140+
template <CompareFlags::NullHandlingMode nullHandlingMode>
141+
static std::optional<bool> equals(const Variant& a, const Variant& b) {
126142
if (a.isNull() || b.isNull()) {
127-
return evaluateNullEquality<NullEqualsNull>(a, b);
143+
return evaluateNullEquality(a, b, nullHandlingMode);
128144
}
129145

130146
auto& aMap = a.value<TypeKind::MAP>();
@@ -134,32 +150,55 @@ struct VariantEquality<TypeKind::MAP> {
134150
return false;
135151
}
136152
// compare map values
153+
bool isComparisonIndeterminate = false;
137154
for (auto it_a = aMap.begin(), it_b = bMap.begin();
138155
it_a != aMap.end() && it_b != bMap.end();
139156
++it_a, ++it_b) {
140-
if (dispatchDynamicVariantEquality(
141-
it_a->first, it_b->first, NullEqualsNull) &&
142-
dispatchDynamicVariantEquality(
143-
it_a->second, it_b->second, NullEqualsNull)) {
144-
continue;
157+
auto keysCompareResult = dispatchDynamicVariantEquality(
158+
it_a->first, it_b->first, nullHandlingMode);
159+
if (keysCompareResult.has_value()) {
160+
if (!keysCompareResult.value()) {
161+
return false;
162+
}
145163
} else {
146-
return false;
164+
isComparisonIndeterminate = true;
165+
}
166+
167+
auto valuesCompareResult = dispatchDynamicVariantEquality(
168+
it_a->second, it_b->second, nullHandlingMode);
169+
if (valuesCompareResult.has_value()) {
170+
if (!valuesCompareResult.value()) {
171+
return false;
172+
}
173+
} else {
174+
isComparisonIndeterminate = true;
147175
}
148176
}
177+
if (isComparisonIndeterminate) {
178+
return std::nullopt;
179+
}
149180
return true;
150181
}
151182
};
152183

153-
bool dispatchDynamicVariantEquality(
184+
std::optional<bool> dispatchDynamicVariantEquality(
154185
const Variant& a,
155186
const Variant& b,
156-
const bool& enableNullEqualsNull) {
157-
if (enableNullEqualsNull) {
158-
return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD(
159-
VariantEquality, equals<true>, a.kind(), a, b);
187+
CompareFlags::NullHandlingMode nullHandlingMode) {
188+
if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue) {
189+
return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD_ALL(
190+
VariantEquality,
191+
equals<CompareFlags::NullHandlingMode::kNullAsValue>,
192+
a.kind(),
193+
a,
194+
b);
160195
}
161-
return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD(
162-
VariantEquality, equals<false>, a.kind(), a, b);
196+
return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD_ALL(
197+
VariantEquality,
198+
equals<CompareFlags::NullHandlingMode::kNullAsIndeterminate>,
199+
a.kind(),
200+
a,
201+
b);
163202
}
164203

165204
} // namespace
@@ -876,14 +915,24 @@ bool Variant::equals(const Variant& other) const {
876915
return value<KIND>() == other.value<KIND>();
877916
}
878917

879-
bool Variant::equals(const Variant& other) const {
918+
std::optional<bool> Variant::equals(
919+
const Variant& other,
920+
CompareFlags::NullHandlingMode nullHandlingMode) const {
880921
if (other.kind_ != this->kind_) {
881922
return false;
882923
}
883-
if (other.isNull()) {
884-
return this->isNull();
924+
if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue &&
925+
!this->isNull() && !other.isNull()) {
926+
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(equals, kind_, other);
885927
}
886-
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(equals, kind_, other);
928+
return dispatchDynamicVariantEquality(*this, other, nullHandlingMode);
929+
}
930+
931+
bool Variant::equals(const Variant& other) const {
932+
std::optional<bool> compareResult =
933+
this->equals(other, CompareFlags::NullHandlingMode::kNullAsValue);
934+
VELOX_CHECK(compareResult.has_value());
935+
return compareResult.value();
887936
}
888937

889938
template <TypeKind KIND>
@@ -1107,13 +1156,6 @@ void Variant::verifyArrayElements(const std::vector<Variant>& inputs) {
11071156
}
11081157
}
11091158

1110-
bool Variant::equalsWithNullEqualsNull(const Variant& other) const {
1111-
if (other.kind_ != this->kind_) {
1112-
return false;
1113-
}
1114-
return dispatchDynamicVariantEquality(*this, other, true);
1115-
}
1116-
11171159
TypePtr Variant::inferType() const {
11181160
switch (kind_) {
11191161
case TypeKind::MAP: {

velox/type/Variant.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <folly/Conv.h>
2323

2424
#include "folly/dynamic.h"
25+
#include "velox/common/base/CompareFlags.h"
2526
#include "velox/common/base/Exceptions.h"
2627
#include "velox/type/Conversions.h"
2728
#include "velox/type/CppToType.h"
@@ -617,7 +618,10 @@ class Variant {
617618

618619
struct NullEqualsNullsComparator {
619620
bool operator()(const Variant& a, const Variant& b) const {
620-
return a.equalsWithNullEqualsNull(b);
621+
auto compareResult =
622+
a.equals(b, CompareFlags::NullHandlingMode::kNullAsValue);
623+
VELOX_CHECK(compareResult.has_value());
624+
return compareResult.value();
621625
}
622626
};
623627

@@ -648,10 +652,25 @@ class Variant {
648652

649653
bool operator<(const Variant& other) const;
650654

655+
/// Compares two `Variant`s using the provided `nullHandlingMode`.
656+
///
657+
/// Returns:
658+
/// - `std::nullopt` when `nullHandlingMode` is `kNullAsIndeterminate` and
659+
/// the comparison is indeterminate.
660+
/// - `true` or `false` when `nullHandlingMode` is `kNullAsValue` and the
661+
/// comparison is determinate.
662+
///
663+
/// See `CompareFlags::NullHandlingMode` for the interpretation of the
664+
/// null-handling semantics for different types.
665+
std::optional<bool> equals(
666+
const Variant& other,
667+
CompareFlags::NullHandlingMode nullHandlingMode) const;
668+
669+
/// Shortcut for equals(..., CompareFlags::NullHandlingMode::kNullAsValue).
670+
/// Treats two null Variants as equal and returns false for any comparison
671+
/// involving a null value and a non-null value.
651672
bool equals(const Variant& other) const;
652673

653-
bool equalsWithNullEqualsNull(const Variant& other) const;
654-
655674
std::string toJson(const TypePtr& type) const;
656675

657676
std::string toJson(const Type& type) const;

0 commit comments

Comments
 (0)