2525namespace facebook ::velox {
2626namespace {
2727
28- bool dispatchDynamicVariantEquality (
28+ std::optional< bool > dispatchDynamicVariantEquality (
2929 const Variant& a,
3030 const Variant& b,
31- const bool & enableNullEqualsNull );
31+ CompareFlags::NullHandlingMode nullHandlingMode );
3232
33- template <bool nullEqualsNull>
34- bool evaluateNullEquality ( const Variant& a, const Variant& b) {
35- if constexpr (nullEqualsNull) {
36- if (a. isNull () && b. isNull () ) {
37- return true ;
38- }
33+ std::optional <bool > evaluateNullEquality (
34+ const Variant& a,
35+ const Variant& b,
36+ CompareFlags::NullHandlingMode nullHandlingMode ) {
37+ if (nullHandlingMode == CompareFlags::NullHandlingMode:: kNullAsValue ) {
38+ return a. isNull () && b. isNull ();
3939 }
40- return false ;
40+ return std:: nullopt ;
4141}
4242
4343template <TypeKind KIND>
@@ -46,10 +46,10 @@ struct VariantEquality;
4646// scalars
4747template <TypeKind KIND>
4848struct VariantEquality {
49- template <bool NullEqualsNull >
50- static bool equals (const Variant& a, const Variant& b) {
49+ template <CompareFlags::NullHandlingMode nullHandlingMode >
50+ static std::optional< bool > equals (const Variant& a, const Variant& b) {
5151 if (a.isNull () || b.isNull ()) {
52- return evaluateNullEquality<NullEqualsNull> (a, b);
52+ return evaluateNullEquality (a, b, nullHandlingMode );
5353 }
5454 return a.value <KIND>() == b.value <KIND>();
5555 }
@@ -58,10 +58,10 @@ struct VariantEquality {
5858// timestamp
5959template <>
6060struct VariantEquality <TypeKind::TIMESTAMP> {
61- template <bool NullEqualsNull >
62- static bool equals (const Variant& a, const Variant& b) {
61+ template <CompareFlags::NullHandlingMode nullHandlingMode >
62+ static std::optional< bool > equals (const Variant& a, const Variant& b) {
6363 if (a.isNull () || b.isNull ()) {
64- return evaluateNullEquality<NullEqualsNull> (a, b);
64+ return evaluateNullEquality (a, b, nullHandlingMode );
6565 } else {
6666 return a.value <TypeKind::TIMESTAMP>() == b.value <TypeKind::TIMESTAMP>();
6767 }
@@ -71,34 +71,42 @@ struct VariantEquality<TypeKind::TIMESTAMP> {
7171// array
7272template <>
7373struct VariantEquality <TypeKind::ARRAY> {
74- template <bool NullEqualsNull >
75- static bool equals (const Variant& a, const Variant& b) {
74+ template <CompareFlags::NullHandlingMode nullHandlingMode >
75+ static std::optional< bool > equals (const Variant& a, const Variant& b) {
7676 if (a.isNull () || b.isNull ()) {
77- return evaluateNullEquality<NullEqualsNull> (a, b);
77+ return evaluateNullEquality (a, b, nullHandlingMode );
7878 }
7979 auto & aArray = a.value <TypeKind::ARRAY>();
8080 auto & bArray = b.value <TypeKind::ARRAY>();
8181 if (aArray.size () != bArray.size ()) {
8282 return false ;
8383 }
84+ bool isComparisonIndeterminate = false ;
8485 for (size_t i = 0 ; i != aArray.size (); ++i) {
8586 // todo(youknowjack): switch outside the loop
86- bool result =
87- dispatchDynamicVariantEquality (aArray[i], bArray[i], NullEqualsNull);
88- if (!result) {
89- return false ;
87+ auto compareResult = dispatchDynamicVariantEquality (
88+ aArray[i], bArray[i], nullHandlingMode);
89+ if (compareResult.has_value ()) {
90+ if (!compareResult.value ()) {
91+ return false ;
92+ }
93+ } else {
94+ isComparisonIndeterminate = true ;
9095 }
9196 }
97+ if (isComparisonIndeterminate) {
98+ return std::nullopt ;
99+ }
92100 return true ;
93101 }
94102};
95103
96104template <>
97105struct VariantEquality <TypeKind::ROW> {
98- template <bool NullEqualsNull >
99- static bool equals (const Variant& a, const Variant& b) {
106+ template <CompareFlags::NullHandlingMode nullHandlingMode >
107+ static std::optional< bool > equals (const Variant& a, const Variant& b) {
100108 if (a.isNull () || b.isNull ()) {
101- return evaluateNullEquality<NullEqualsNull> (a, b);
109+ return evaluateNullEquality (a, b, nullHandlingMode );
102110 }
103111 auto & aRow = a.value <TypeKind::ROW>();
104112 auto & bRow = b.value <TypeKind::ROW>();
@@ -108,23 +116,31 @@ struct VariantEquality<TypeKind::ROW> {
108116 return false ;
109117 }
110118 // compare array values
119+ bool isComparisonIndeterminate = false ;
111120 for (size_t i = 0 ; i != aRow.size (); ++i) {
112- bool result =
113- dispatchDynamicVariantEquality (aRow[i], bRow[i], NullEqualsNull);
114- if (!result) {
115- return false ;
121+ auto compareResult =
122+ dispatchDynamicVariantEquality (aRow[i], bRow[i], nullHandlingMode);
123+ if (compareResult.has_value ()) {
124+ if (!compareResult.value ()) {
125+ return false ;
126+ }
127+ } else {
128+ isComparisonIndeterminate = true ;
116129 }
117130 }
131+ if (isComparisonIndeterminate) {
132+ return std::nullopt ;
133+ }
118134 return true ;
119135 }
120136};
121137
122138template <>
123139struct VariantEquality <TypeKind::MAP> {
124- template <bool NullEqualsNull >
125- static bool equals (const Variant& a, const Variant& b) {
140+ template <CompareFlags::NullHandlingMode nullHandlingMode >
141+ static std::optional< bool > equals (const Variant& a, const Variant& b) {
126142 if (a.isNull () || b.isNull ()) {
127- return evaluateNullEquality<NullEqualsNull> (a, b);
143+ return evaluateNullEquality (a, b, nullHandlingMode );
128144 }
129145
130146 auto & aMap = a.value <TypeKind::MAP>();
@@ -134,32 +150,55 @@ struct VariantEquality<TypeKind::MAP> {
134150 return false ;
135151 }
136152 // compare map values
153+ bool isComparisonIndeterminate = false ;
137154 for (auto it_a = aMap.begin (), it_b = bMap.begin ();
138155 it_a != aMap.end () && it_b != bMap.end ();
139156 ++it_a, ++it_b) {
140- if (dispatchDynamicVariantEquality (
141- it_a->first , it_b->first , NullEqualsNull) &&
142- dispatchDynamicVariantEquality (
143- it_a->second , it_b->second , NullEqualsNull)) {
144- continue ;
157+ auto keysCompareResult = dispatchDynamicVariantEquality (
158+ it_a->first , it_b->first , nullHandlingMode);
159+ if (keysCompareResult.has_value ()) {
160+ if (!keysCompareResult.value ()) {
161+ return false ;
162+ }
145163 } else {
146- return false ;
164+ isComparisonIndeterminate = true ;
165+ }
166+
167+ auto valuesCompareResult = dispatchDynamicVariantEquality (
168+ it_a->second , it_b->second , nullHandlingMode);
169+ if (valuesCompareResult.has_value ()) {
170+ if (!valuesCompareResult.value ()) {
171+ return false ;
172+ }
173+ } else {
174+ isComparisonIndeterminate = true ;
147175 }
148176 }
177+ if (isComparisonIndeterminate) {
178+ return std::nullopt ;
179+ }
149180 return true ;
150181 }
151182};
152183
153- bool dispatchDynamicVariantEquality (
184+ std::optional< bool > dispatchDynamicVariantEquality (
154185 const Variant& a,
155186 const Variant& b,
156- const bool & enableNullEqualsNull) {
157- if (enableNullEqualsNull) {
158- return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD (
159- VariantEquality, equals<true >, a.kind (), a, b);
187+ CompareFlags::NullHandlingMode nullHandlingMode) {
188+ if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue ) {
189+ return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD_ALL (
190+ VariantEquality,
191+ equals<CompareFlags::NullHandlingMode::kNullAsValue >,
192+ a.kind (),
193+ a,
194+ b);
160195 }
161- return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD (
162- VariantEquality, equals<false >, a.kind (), a, b);
196+ return VELOX_DYNAMIC_TYPE_DISPATCH_METHOD_ALL (
197+ VariantEquality,
198+ equals<CompareFlags::NullHandlingMode::kNullAsIndeterminate >,
199+ a.kind (),
200+ a,
201+ b);
163202}
164203
165204} // namespace
@@ -876,14 +915,24 @@ bool Variant::equals(const Variant& other) const {
876915 return value<KIND>() == other.value <KIND>();
877916}
878917
879- bool Variant::equals (const Variant& other) const {
918+ std::optional<bool > Variant::equals (
919+ const Variant& other,
920+ CompareFlags::NullHandlingMode nullHandlingMode) const {
880921 if (other.kind_ != this ->kind_ ) {
881922 return false ;
882923 }
883- if (other.isNull ()) {
884- return this ->isNull ();
924+ if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue &&
925+ !this ->isNull () && !other.isNull ()) {
926+ return VELOX_DYNAMIC_TYPE_DISPATCH_ALL (equals, kind_, other);
885927 }
886- return VELOX_DYNAMIC_TYPE_DISPATCH_ALL (equals, kind_, other);
928+ return dispatchDynamicVariantEquality (*this , other, nullHandlingMode);
929+ }
930+
931+ bool Variant::equals (const Variant& other) const {
932+ std::optional<bool > compareResult =
933+ this ->equals (other, CompareFlags::NullHandlingMode::kNullAsValue );
934+ VELOX_CHECK (compareResult.has_value ());
935+ return compareResult.value ();
887936}
888937
889938template <TypeKind KIND>
@@ -1107,13 +1156,6 @@ void Variant::verifyArrayElements(const std::vector<Variant>& inputs) {
11071156 }
11081157}
11091158
1110- bool Variant::equalsWithNullEqualsNull (const Variant& other) const {
1111- if (other.kind_ != this ->kind_ ) {
1112- return false ;
1113- }
1114- return dispatchDynamicVariantEquality (*this , other, true );
1115- }
1116-
11171159TypePtr Variant::inferType () const {
11181160 switch (kind_) {
11191161 case TypeKind::MAP: {
0 commit comments