Skip to content

Commit 52d615d

Browse files
committed
ARROW-10277: [C++] Support comparing scalars approximately
As discussed in #7748 (comment), we need to compare scalars approximately in some scenarios. Also: * Fix comparison of same-pointer NaN values * Fix scalar comparison result when both inputs are null (ARROW-8956) * Fix scalar kernel result type when result is null scalar Closes #8438 from liyafan82/fly_1012_scl Authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 49fb0f5 commit 52d615d

File tree

10 files changed

+444
-111
lines changed

10 files changed

+444
-111
lines changed

cpp/src/arrow/compare.cc

Lines changed: 199 additions & 94 deletions
Large diffs are not rendered by default.

cpp/src/arrow/compare.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,12 @@ bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
122122
bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right,
123123
const EqualOptions& options = EqualOptions::Defaults());
124124

125+
/// Returns true if scalars are approximately equal
126+
/// \param[in] left a Scalar
127+
/// \param[in] right a Scalar
128+
/// \param[in] options comparison options
129+
bool ARROW_EXPORT
130+
ScalarApproxEquals(const Scalar& left, const Scalar& right,
131+
const EqualOptions& options = EqualOptions::Defaults());
132+
125133
} // namespace arrow

cpp/src/arrow/compute/kernels/codegen_internal.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -512,28 +512,30 @@ struct ScalarUnary {
512512
using OutValue = typename GetOutputType<OutType>::T;
513513
using Arg0Value = typename GetViewType<Arg0Type>::T;
514514

515-
static void Array(KernelContext* ctx, const ArrayData& arg0, Datum* out) {
515+
static void ExecArray(KernelContext* ctx, const ArrayData& arg0, Datum* out) {
516516
ArrayIterator<Arg0Type> arg0_it(arg0);
517517
OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
518518
return Op::template Call<OutValue, Arg0Value>(ctx, arg0_it());
519519
});
520520
}
521521

522-
static void Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
522+
static void ExecScalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
523+
Scalar* out_scalar = out->scalar().get();
523524
if (arg0.is_valid) {
524525
Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
526+
out_scalar->is_valid = true;
525527
BoxScalar<OutType>::Box(Op::template Call<OutValue, Arg0Value>(ctx, arg0_val),
526-
out->scalar().get());
528+
out_scalar);
527529
} else {
528-
out->value = MakeNullScalar(arg0.type);
530+
out_scalar->is_valid = false;
529531
}
530532
}
531533

532534
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
533535
if (batch[0].kind() == Datum::ARRAY) {
534-
return Array(ctx, *batch[0].array(), out);
536+
return ExecArray(ctx, *batch[0].array(), out);
535537
} else {
536-
return Scalar(ctx, *batch[0].scalar(), out);
538+
return ExecScalar(ctx, *batch[0].scalar(), out);
537539
}
538540
}
539541
};

cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class TestBinaryArithmetic : public TestBase {
8484
auto exp = MakeScalar(expected);
8585

8686
ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
87-
AssertScalarsEqual(*exp, *actual.scalar(), /*verbose=*/true);
87+
AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true);
8888
}
8989

9090
// (Scalar, Array)
@@ -144,8 +144,8 @@ class TestBinaryArithmetic : public TestBase {
144144
const auto expected_scalar = *expected->GetScalar(i);
145145
ASSERT_OK_AND_ASSIGN(
146146
actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr));
147-
AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
148-
equal_options_);
147+
AssertScalarsApproxEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
148+
equal_options_);
149149
}
150150
}
151151

cpp/src/arrow/compute/kernels/scalar_cast_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class TestCast : public TestBase {
9494
AssertArraysEqual(expected, *result, /*verbose=*/true);
9595

9696
if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) {
97-
// ARROW-9194
97+
// ARROW-10835
9898
check_scalar = false;
9999
}
100100

@@ -111,7 +111,7 @@ class TestCast : public TestBase {
111111
ASSERT_RAISES(Invalid, Cast(input, out_type, options));
112112

113113
if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) {
114-
// ARROW-9194
114+
// ARROW-10835
115115
check_scalar = false;
116116
}
117117

cpp/src/arrow/scalar.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const {
4444
return ScalarEquals(*this, other, options);
4545
}
4646

47+
bool Scalar::ApproxEquals(const Scalar& other, const EqualOptions& options) const {
48+
return ScalarApproxEquals(*this, other, options);
49+
}
50+
4751
struct ScalarHashImpl {
4852
static std::hash<std::string> string_hash;
4953

cpp/src/arrow/scalar.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ struct ARROW_EXPORT Scalar : public util::EqualityComparable<Scalar> {
6565
bool Equals(const Scalar& other,
6666
const EqualOptions& options = EqualOptions::Defaults()) const;
6767

68+
bool ApproxEquals(const Scalar& other,
69+
const EqualOptions& options = EqualOptions::Defaults()) const;
70+
6871
struct ARROW_EXPORT Hash {
6972
size_t operator()(const Scalar& scalar) const { return hash(scalar); }
7073

cpp/src/arrow/scalar_test.cc

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#include <limits>
1819
#include <memory>
20+
#include <ostream>
1921
#include <string>
2022
#include <unordered_set>
2123
#include <utility>
@@ -96,6 +98,12 @@ TYPED_TEST(TestNumericScalar, Basics) {
9698
ASSERT_FALSE(one->Equals(ScalarType(2)));
9799
ASSERT_TRUE(two->Equals(ScalarType(2)));
98100
ASSERT_FALSE(two->Equals(ScalarType(3)));
101+
102+
ASSERT_TRUE(null->ApproxEquals(*null_value));
103+
ASSERT_TRUE(one->ApproxEquals(ScalarType(1)));
104+
ASSERT_FALSE(one->ApproxEquals(ScalarType(2)));
105+
ASSERT_TRUE(two->ApproxEquals(ScalarType(2)));
106+
ASSERT_FALSE(two->ApproxEquals(ScalarType(3)));
99107
}
100108

101109
TYPED_TEST(TestNumericScalar, Hashing) {
@@ -127,6 +135,199 @@ TYPED_TEST(TestNumericScalar, MakeScalar) {
127135
ASSERT_EQ(ScalarType(3), *three);
128136
}
129137

138+
template <typename T>
139+
class TestRealScalar : public ::testing::Test {
140+
public:
141+
using CType = typename T::c_type;
142+
using ScalarType = typename TypeTraits<T>::ScalarType;
143+
144+
void SetUp() {
145+
type_ = TypeTraits<T>::type_singleton();
146+
147+
scalar_val_ = std::make_shared<ScalarType>(static_cast<CType>(1));
148+
ASSERT_TRUE(scalar_val_->is_valid);
149+
150+
scalar_other_ = std::make_shared<ScalarType>(static_cast<CType>(1.1));
151+
ASSERT_TRUE(scalar_other_->is_valid);
152+
153+
const CType nan_value = std::numeric_limits<CType>::quiet_NaN();
154+
scalar_nan_ = std::make_shared<ScalarType>(nan_value);
155+
ASSERT_TRUE(scalar_nan_->is_valid);
156+
157+
const CType other_nan_value = std::numeric_limits<CType>::quiet_NaN();
158+
scalar_other_nan_ = std::make_shared<ScalarType>(other_nan_value);
159+
ASSERT_TRUE(scalar_other_nan_->is_valid);
160+
}
161+
162+
void TestNanEquals() {
163+
EqualOptions options = EqualOptions::Defaults();
164+
ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
165+
ASSERT_FALSE(scalar_nan_->Equals(*scalar_nan_, options));
166+
ASSERT_FALSE(scalar_nan_->Equals(*scalar_other_nan_, options));
167+
168+
options = options.nans_equal(true);
169+
ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
170+
ASSERT_TRUE(scalar_nan_->Equals(*scalar_nan_, options));
171+
ASSERT_TRUE(scalar_nan_->Equals(*scalar_other_nan_, options));
172+
}
173+
174+
void TestApproxEquals() {
175+
// The scalars are unequal with the small delta
176+
EqualOptions options = EqualOptions::Defaults().atol(0.05);
177+
ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
178+
ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
179+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
180+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
181+
182+
// After enlarging the delta, they become equal
183+
options = options.atol(0.15);
184+
ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
185+
ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
186+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
187+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
188+
189+
options = options.nans_equal(true);
190+
ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
191+
ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
192+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
193+
ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
194+
195+
options = options.atol(0.05);
196+
ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
197+
ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
198+
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
199+
ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
200+
}
201+
202+
void TestStructOf() {
203+
auto ty = struct_({field("float", type_)});
204+
205+
StructScalar struct_val({scalar_val_}, ty);
206+
StructScalar struct_other_val({scalar_other_}, ty);
207+
StructScalar struct_nan({scalar_nan_}, ty);
208+
StructScalar struct_other_nan({scalar_other_nan_}, ty);
209+
210+
EqualOptions options = EqualOptions::Defaults().atol(0.05);
211+
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
212+
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
213+
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
214+
ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
215+
ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
216+
ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
217+
ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
218+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
219+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
220+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));
221+
222+
options = options.atol(0.15);
223+
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
224+
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
225+
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
226+
ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
227+
ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
228+
ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
229+
ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
230+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
231+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
232+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));
233+
234+
options = options.nans_equal(true);
235+
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
236+
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
237+
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
238+
ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
239+
ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
240+
ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
241+
ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
242+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
243+
ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
244+
ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
245+
246+
options = options.atol(0.05);
247+
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
248+
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
249+
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
250+
ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
251+
ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
252+
ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
253+
ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
254+
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
255+
ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
256+
ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
257+
}
258+
259+
void TestListOf() {
260+
auto ty = list(type_);
261+
262+
ListScalar list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), ty);
263+
ListScalar list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), ty);
264+
ListScalar list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
265+
ListScalar list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
266+
267+
EqualOptions options = EqualOptions::Defaults().atol(0.05);
268+
ASSERT_TRUE(list_val.Equals(list_val, options));
269+
ASSERT_FALSE(list_val.Equals(list_other_val, options));
270+
ASSERT_FALSE(list_nan.Equals(list_val, options));
271+
ASSERT_FALSE(list_nan.Equals(list_nan, options));
272+
ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
273+
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
274+
ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
275+
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
276+
ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
277+
ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));
278+
279+
options = options.atol(0.15);
280+
ASSERT_TRUE(list_val.Equals(list_val, options));
281+
ASSERT_FALSE(list_val.Equals(list_other_val, options));
282+
ASSERT_FALSE(list_nan.Equals(list_val, options));
283+
ASSERT_FALSE(list_nan.Equals(list_nan, options));
284+
ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
285+
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
286+
ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
287+
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
288+
ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
289+
ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));
290+
291+
options = options.nans_equal(true);
292+
ASSERT_TRUE(list_val.Equals(list_val, options));
293+
ASSERT_FALSE(list_val.Equals(list_other_val, options));
294+
ASSERT_FALSE(list_nan.Equals(list_val, options));
295+
ASSERT_TRUE(list_nan.Equals(list_nan, options));
296+
ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
297+
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
298+
ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
299+
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
300+
ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
301+
ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
302+
303+
options = options.atol(0.05);
304+
ASSERT_TRUE(list_val.Equals(list_val, options));
305+
ASSERT_FALSE(list_val.Equals(list_other_val, options));
306+
ASSERT_FALSE(list_nan.Equals(list_val, options));
307+
ASSERT_TRUE(list_nan.Equals(list_nan, options));
308+
ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
309+
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
310+
ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
311+
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
312+
ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
313+
ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
314+
}
315+
316+
protected:
317+
std::shared_ptr<DataType> type_;
318+
std::shared_ptr<Scalar> scalar_val_, scalar_other_, scalar_nan_, scalar_other_nan_;
319+
};
320+
321+
TYPED_TEST_SUITE(TestRealScalar, RealArrowTypes);
322+
323+
TYPED_TEST(TestRealScalar, NanEquals) { this->TestNanEquals(); }
324+
325+
TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); }
326+
327+
TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); }
328+
329+
TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }
330+
130331
TEST(TestDecimal128Scalar, Basics) {
131332
auto ty = decimal128(3, 2);
132333
auto pi = Decimal128Scalar(Decimal128("3.14"), ty);

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,20 @@ void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool ve
148148

149149
void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
150150
const EqualOptions& options) {
151-
std::stringstream diff;
152-
// ARROW-8956, ScalarEquals returns false when both are null
153-
if (!expected.is_valid && !actual.is_valid) {
154-
// We consider both being null to be equal in this function
155-
return;
156-
}
157151
if (!expected.Equals(actual, options)) {
152+
std::stringstream diff;
153+
if (verbose) {
154+
diff << "Expected:\n" << expected.ToString();
155+
diff << "\nActual:\n" << actual.ToString();
156+
}
157+
FAIL() << diff.str();
158+
}
159+
}
160+
161+
void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose,
162+
const EqualOptions& options) {
163+
if (!expected.ApproxEquals(actual, options)) {
164+
std::stringstream diff;
158165
if (verbose) {
159166
diff << "Expected:\n" << expected.ToString();
160167
diff << "\nActual:\n" << actual.ToString();

cpp/src/arrow/testing/gtest_util.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ ARROW_TESTING_EXPORT void AssertArraysApproxEqual(
171171
ARROW_TESTING_EXPORT void AssertScalarsEqual(
172172
const Scalar& expected, const Scalar& actual, bool verbose = false,
173173
const EqualOptions& options = EqualOptions::Defaults());
174+
ARROW_TESTING_EXPORT void AssertScalarsApproxEqual(
175+
const Scalar& expected, const Scalar& actual, bool verbose = false,
176+
const EqualOptions& options = EqualOptions::Defaults());
174177
ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected,
175178
const RecordBatch& actual,
176179
bool check_metadata = false);

0 commit comments

Comments
 (0)