Skip to content

Commit 084a365

Browse files
andishgar and kou
authored
GH-46835: [C++] Add more configuration options to arrow::EqualOptions (#47204)
### Rationale for this change

Parameterize the comparison of `metadata` and `arrow::Schema` via `arrow::EqualOptions`.

### What changes are included in this PR?

* Added two attributes to `arrow::EqualOptions` for parameterizing the comparison of `metadata` and `schema`.

### Are these changes tested?

Yes, I ran the relevant unit tests.

### Are there any user-facing changes?

* Added `use_schema()` to `arrow::EqualOptions`.
* Added `use_metadata()` to `arrow::EqualOptions`.
* Added a new overload for `arrow::RecordBatch::Equals` that accepts only another `arrow::RecordBatch` and an `arrow::EqualOptions` instance.

* GitHub Issue: #46835

Lead-authored-by: Arash Andishgar <[email protected]>
Co-authored-by: Sutou Kouhei <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
1 parent 55be8c0 commit 084a365

File tree

5 files changed

+165
-60
lines changed

5 files changed

+165
-60
lines changed

cpp/src/arrow/compare.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,38 @@ class EqualOptions {
8383
return res;
8484
}
8585

86+
/// Whether the \ref arrow::Schema property is used in the comparison.
87+
///
88+
/// This option only affects the Equals methods
89+
/// and has no effect on ApproxEquals methods.
90+
bool use_schema() const { return use_schema_; }
91+
92+
/// Return a new EqualOptions object with the "use_schema" property changed.
93+
///
94+
/// Setting this option to false causes the value of \ref EqualOptions::use_metadata
95+
/// to be ignored.
96+
EqualOptions use_schema(bool v) const {
97+
auto res = EqualOptions(*this);
98+
res.use_schema_ = v;
99+
return res;
100+
}
101+
102+
/// Whether the "metadata" in \ref arrow::Schema is used in the comparison.
103+
///
104+
/// This option only affects the Equals methods
105+
/// and has no effect on the ApproxEquals methods.
106+
///
107+
/// Note: This option is only considered when \ref arrow::EqualOptions::use_schema is
108+
/// set to true.
109+
bool use_metadata() const { return use_metadata_; }
110+
111+
/// Return a new EqualOptions object with the "use_metadata" property changed.
112+
EqualOptions use_metadata(bool v) const {
113+
auto res = EqualOptions(*this);
114+
res.use_metadata_ = v;
115+
return res;
116+
}
117+
86118
/// The ostream to which a diff will be formatted if arrays disagree.
87119
/// If this is null (the default) no diff will be formatted.
88120
std::ostream* diff_sink() const { return diff_sink_; }
@@ -103,6 +135,8 @@ class EqualOptions {
103135
bool nans_equal_ = false;
104136
bool signed_zeros_equal_ = true;
105137
bool use_atol_ = false;
138+
bool use_schema_ = true;
139+
bool use_metadata_ = false;
106140

107141
std::ostream* diff_sink_ = NULLPTR;
108142
};

cpp/src/arrow/record_batch.cc

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "arrow/array/statistics.h"
3636
#include "arrow/array/validate.h"
3737
#include "arrow/c/abi.h"
38+
#include "arrow/compare.h"
3839
#include "arrow/pretty_print.h"
3940
#include "arrow/status.h"
4041
#include "arrow/table.h"
@@ -349,44 +350,27 @@ bool CanIgnoreNaNInEquality(const RecordBatch& batch, const EqualOptions& opts)
349350

350351
bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata,
351352
const EqualOptions& opts) const {
352-
if (this == &other) {
353-
if (CanIgnoreNaNInEquality(*this, opts)) {
354-
return true;
355-
}
356-
} else {
357-
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
358-
return false;
359-
} else if (!schema_->Equals(*other.schema(), check_metadata)) {
360-
return false;
361-
} else if (device_type() != other.device_type()) {
362-
return false;
363-
}
364-
}
365-
366-
for (int i = 0; i < num_columns(); ++i) {
367-
if (!column(i)->Equals(other.column(i), opts)) {
368-
return false;
369-
}
370-
}
371-
372-
return true;
353+
return Equals(other, opts.use_metadata(check_metadata));
373354
}
374355

375-
bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opts) const {
356+
bool RecordBatch::Equals(const RecordBatch& other, const EqualOptions& opts) const {
376357
if (this == &other) {
377358
if (CanIgnoreNaNInEquality(*this, opts)) {
378359
return true;
379360
}
380361
} else {
381362
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
382363
return false;
364+
} else if (opts.use_schema() &&
365+
!schema_->Equals(*other.schema(), opts.use_metadata())) {
366+
return false;
383367
} else if (device_type() != other.device_type()) {
384368
return false;
385369
}
386370
}
387371

388372
for (int i = 0; i < num_columns(); ++i) {
389-
if (!column(i)->ApproxEquals(other.column(i), opts)) {
373+
if (!column(i)->Equals(other.column(i), opts)) {
390374
return false;
391375
}
392376
}

cpp/src/arrow/record_batch.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,32 @@ class ARROW_EXPORT RecordBatch {
118118
static Result<std::shared_ptr<RecordBatch>> FromStructArray(
119119
const std::shared_ptr<Array>& array, MemoryPool* pool = default_memory_pool());
120120

121-
/// \brief Determine if two record batches are exactly equal
121+
/// \brief Determine if two record batches are equal
122122
///
123123
/// \param[in] other the RecordBatch to compare with
124-
/// \param[in] check_metadata if true, check that Schema metadata is the same
124+
/// \param[in] check_metadata if true, the schema metadata will be compared,
125+
/// regardless of the value set in \ref EqualOptions::use_metadata
125126
/// \param[in] opts the options for equality comparisons
126127
/// \return true if batches are equal
127128
bool Equals(const RecordBatch& other, bool check_metadata = false,
128129
const EqualOptions& opts = EqualOptions::Defaults()) const;
129130

131+
/// \brief Determine if two record batches are equal
132+
///
133+
/// \param[in] other the RecordBatch to compare with
134+
/// \param[in] opts the options for equality comparisons
135+
/// \return true if batches are equal
136+
bool Equals(const RecordBatch& other, const EqualOptions& opts) const;
137+
130138
/// \brief Determine if two record batches are approximately equal
131139
///
132140
/// \param[in] other the RecordBatch to compare with
133141
/// \param[in] opts the options for equality comparisons
134142
/// \return true if batches are approximately equal
135143
bool ApproxEquals(const RecordBatch& other,
136-
const EqualOptions& opts = EqualOptions::Defaults()) const;
144+
const EqualOptions& opts = EqualOptions::Defaults()) const {
145+
return Equals(other, opts.use_schema(false).use_atol(true));
146+
}
137147

138148
/// \return the record batch's schema
139149
const std::shared_ptr<Schema>& schema() const { return schema_; }

cpp/src/arrow/record_batch_test.cc

Lines changed: 105 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "arrow/array/util.h"
4141
#include "arrow/c/abi.h"
4242
#include "arrow/chunked_array.h"
43+
#include "arrow/compare.h"
4344
#include "arrow/config.h"
4445
#include "arrow/status.h"
4546
#include "arrow/table.h"
@@ -64,44 +65,100 @@ class TestRecordBatch : public ::testing::Test {};
6465
TEST_F(TestRecordBatch, Equals) {
6566
const int length = 10;
6667

68+
auto f0 = field("f0", int32());
69+
auto f1 = field("f1", uint8());
70+
auto f2 = field("f2", int16());
71+
72+
auto schema = ::arrow::schema({f0, f1, f2});
73+
auto schema_same = ::arrow::schema({f0, f1, f2});
74+
auto schema_fewer_fields = ::arrow::schema({f0, f1});
75+
76+
random::RandomArrayGenerator gen(42);
77+
78+
auto a_f0 = gen.ArrayOf(int32(), length);
79+
auto a_f1 = gen.ArrayOf(uint8(), length);
80+
auto a_f2 = gen.ArrayOf(int16(), length);
81+
auto a_f0_half = a_f0->Slice(0, length / 2);
82+
auto a_f1_half = a_f1->Slice(0, length / 2);
83+
auto a_f0_different = gen.ArrayOf(int32(), length);
84+
auto a_f1_different = gen.ArrayOf(uint8(), length);
85+
86+
auto b = RecordBatch::Make(schema, length, {a_f0, a_f1, a_f2});
87+
auto b_same = RecordBatch::Make(schema_same, length, {a_f0, a_f1, a_f2});
88+
auto b_fewer_fields = RecordBatch::Make(schema_fewer_fields, length, {a_f0, a_f1});
89+
auto b_fewer_fields_half =
90+
RecordBatch::Make(schema_fewer_fields, length / 2, {a_f0_half, a_f1_half});
91+
auto b_fewer_fields_different =
92+
RecordBatch::Make(schema_fewer_fields, length, {a_f0_different, a_f1_different});
93+
94+
// Same Values
95+
ASSERT_TRUE(b->Equals(*b_same));
96+
97+
// Different number of columns
98+
ASSERT_FALSE(b->Equals(*b_fewer_fields));
99+
100+
// Different number of rows
101+
ASSERT_FALSE(b_fewer_fields->Equals(*b_fewer_fields_half));
102+
103+
// Different values
104+
ASSERT_FALSE(b_fewer_fields->Equals(*b_fewer_fields_different));
105+
}
106+
107+
class TestRecordBatchEqualOptions : public TestRecordBatch {};
108+
109+
TEST_F(TestRecordBatchEqualOptions, MetadataAndSchema) {
110+
int length = 10;
111+
67112
auto f0 = field("f0", int32());
68113
auto f1 = field("f1", uint8());
69114
auto f2 = field("f2", int16());
70115
auto f2b = field("f2b", int16());
71116

72117
auto metadata = key_value_metadata({"foo"}, {"bar"});
73118

74-
std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
75119
auto schema = ::arrow::schema({f0, f1, f2});
76-
auto schema2 = ::arrow::schema({f0, f1});
77-
auto schema3 = ::arrow::schema({f0, f1, f2}, metadata);
78-
auto schema4 = ::arrow::schema({f0, f1, f2b});
120+
auto schema_with_metadata = ::arrow::schema({f0, f1, f2}, metadata);
121+
auto schema_renamed_field = ::arrow::schema({f0, f1, f2b});
79122

80123
random::RandomArrayGenerator gen(42);
81124

82-
auto a0 = gen.ArrayOf(int32(), length);
83-
auto a1 = gen.ArrayOf(uint8(), length);
84-
auto a2 = gen.ArrayOf(int16(), length);
125+
auto a_f0 = gen.ArrayOf(int32(), length);
126+
auto a_f1 = gen.ArrayOf(uint8(), length);
127+
auto a_f2 = gen.ArrayOf(int16(), length);
128+
auto a_f2b = a_f2;
85129

86-
auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
87-
auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2});
88-
auto b3 = RecordBatch::Make(schema2, length, {a0, a1});
89-
auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1});
90-
auto b5 = RecordBatch::Make(schema4, length, {a0, a1, a2});
130+
// All RecordBatches have the same values but different schemas.
131+
auto b = RecordBatch::Make(schema, length, {a_f0, a_f1, a_f2});
132+
auto b_with_metadata =
133+
RecordBatch::Make(schema_with_metadata, length, {a_f0, a_f1, a_f2});
134+
auto b_renamed_field =
135+
RecordBatch::Make(schema_renamed_field, length, {a_f0, a_f1, a_f2b});
91136

92-
ASSERT_TRUE(b1->Equals(*b1));
93-
ASSERT_FALSE(b1->Equals(*b3));
94-
ASSERT_FALSE(b1->Equals(*b4));
137+
auto options = EqualOptions::Defaults();
95138

96139
// Same values and types, but different field names
97-
ASSERT_FALSE(b1->Equals(*b5));
140+
ASSERT_FALSE(b->Equals(*b_renamed_field));
141+
ASSERT_TRUE(b->Equals(*b_renamed_field, options.use_schema(false)));
142+
ASSERT_TRUE(b->ApproxEquals(*b_renamed_field));
143+
ASSERT_TRUE(b->ApproxEquals(*b_renamed_field, options.use_schema(true)));
98144

99145
// Different metadata
100-
ASSERT_TRUE(b1->Equals(*b2));
101-
ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true));
146+
ASSERT_TRUE(b->Equals(*b_with_metadata));
147+
ASSERT_TRUE(b->Equals(*b_with_metadata, options));
148+
ASSERT_FALSE(b->Equals(*b_with_metadata,
149+
/*check_metadata=*/true));
150+
ASSERT_FALSE(b->Equals(*b_with_metadata,
151+
/*check_metadata=*/true, options.use_schema(true)));
152+
ASSERT_TRUE(b->Equals(*b_with_metadata,
153+
/*check_metadata=*/true, options.use_schema(false)));
154+
ASSERT_TRUE(b->Equals(*b_with_metadata, options.use_schema(true).use_metadata(false)));
155+
ASSERT_FALSE(b->Equals(*b_with_metadata, options.use_schema(true).use_metadata(true)));
156+
ASSERT_TRUE(b->Equals(*b_with_metadata, options.use_schema(false).use_metadata(true)));
157+
ASSERT_TRUE(
158+
b->ApproxEquals(*b_with_metadata, options.use_schema(true).use_metadata(true)));
102159
}
103160

104-
TEST_F(TestRecordBatch, EqualOptions) {
161+
TEST_F(TestRecordBatchEqualOptions, NaN) {
105162
int length = 2;
106163
auto f = field("f", float64());
107164

@@ -114,13 +171,27 @@ TEST_F(TestRecordBatch, EqualOptions) {
114171
auto b1 = RecordBatch::Make(schema, length, {array1});
115172
auto b2 = RecordBatch::Make(schema, length, {array2});
116173

117-
EXPECT_FALSE(b1->Equals(*b2, /*check_metadata=*/false,
118-
EqualOptions::Defaults().nans_equal(false)));
119-
EXPECT_TRUE(b1->Equals(*b2, /*check_metadata=*/false,
120-
EqualOptions::Defaults().nans_equal(true)));
174+
EXPECT_FALSE(b1->Equals(*b2, EqualOptions::Defaults().nans_equal(false)));
175+
EXPECT_TRUE(b1->Equals(*b2, EqualOptions::Defaults().nans_equal(true)));
176+
}
177+
178+
TEST_F(TestRecordBatchEqualOptions, SignedZero) {
179+
int length = 2;
180+
auto f = field("f", float64());
181+
182+
auto schema = ::arrow::schema({f});
183+
184+
std::shared_ptr<Array> array1, array2;
185+
ArrayFromVector<DoubleType>(float64(), {true, true}, {0.5, +0.0}, &array1);
186+
ArrayFromVector<DoubleType>(float64(), {true, true}, {0.5, -0.0}, &array2);
187+
auto b1 = RecordBatch::Make(schema, length, {array1});
188+
auto b2 = RecordBatch::Make(schema, length, {array2});
189+
190+
ASSERT_FALSE(b1->Equals(*b2, EqualOptions::Defaults().signed_zeros_equal(false)));
191+
ASSERT_TRUE(b1->Equals(*b2, EqualOptions::Defaults().signed_zeros_equal(true)));
121192
}
122193

123-
TEST_F(TestRecordBatch, ApproxEqualOptions) {
194+
TEST_F(TestRecordBatchEqualOptions, Approx) {
124195
int length = 2;
125196
auto f = field("f", float64());
126197

@@ -137,8 +208,8 @@ TEST_F(TestRecordBatch, ApproxEqualOptions) {
137208
EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true)));
138209

139210
auto options = EqualOptions::Defaults().nans_equal(true).atol(0.1);
140-
EXPECT_FALSE(b1->Equals(*b2, false, options));
141-
EXPECT_TRUE(b1->Equals(*b2, false, options.use_atol(true)));
211+
EXPECT_FALSE(b1->Equals(*b2, options));
212+
EXPECT_TRUE(b1->Equals(*b2, options.use_atol(true)));
142213
EXPECT_TRUE(b1->ApproxEquals(*b2, options));
143214
}
144215

@@ -158,8 +229,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NonFloatType) {
158229

159230
auto options = EqualOptions::Defaults();
160231

161-
ASSERT_TRUE(b0->Equals(*b1, true, options));
162-
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
232+
ASSERT_TRUE(b0->Equals(*b1, options));
233+
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));
163234

164235
ASSERT_TRUE(b0->ApproxEquals(*b1, options));
165236
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
@@ -180,8 +251,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithoutFloatType) {
180251

181252
auto options = EqualOptions::Defaults();
182253

183-
ASSERT_TRUE(b0->Equals(*b1, true, options));
184-
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
254+
ASSERT_TRUE(b0->Equals(*b1, options));
255+
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));
185256

186257
ASSERT_TRUE(b0->ApproxEquals(*b1, options));
187258
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
@@ -201,8 +272,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, FloatType) {
201272

202273
auto options = EqualOptions::Defaults();
203274

204-
ASSERT_FALSE(b0->Equals(*b1, true, options));
205-
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
275+
ASSERT_FALSE(b0->Equals(*b1, options));
276+
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));
206277

207278
ASSERT_FALSE(b0->ApproxEquals(*b1, options));
208279
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
@@ -223,8 +294,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithFloatType) {
223294

224295
auto options = EqualOptions::Defaults();
225296

226-
ASSERT_FALSE(b0->Equals(*b1, true, options));
227-
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
297+
ASSERT_FALSE(b0->Equals(*b1, options));
298+
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));
228299

229300
ASSERT_FALSE(b0->ApproxEquals(*b1, options));
230301
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));

docs/source/cpp/api/utilities.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ Iterators
4343
.. doxygenclass:: arrow::VectorIterator
4444
:members:
4545

46+
Comparison
47+
==========
48+
49+
.. doxygenclass:: arrow::EqualOptions
50+
:members:
51+
4652
Compression
4753
===========
4854

0 commit comments

Comments
 (0)