Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions cpp/src/arrow/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,38 @@ class EqualOptions {
return res;
}

/// Whether the \ref arrow::Schema property is used in the comparison.
///
/// This option only affects the Equals methods
/// and has no effect on ApproxEquals methods.
bool use_schema() const { return use_schema_; }

/// Return a new EqualOptions object with the "use_schema_" property changed.
///
/// Setting this option is false making the value of \ref EqualOptions::use_metadata_
/// is ignored.
EqualOptions use_schema(bool v) const {
auto res = EqualOptions(*this);
res.use_schema_ = v;
return res;
}

/// Whether the "metadata" in \ref arrow::Schema is used in the comparison.
///
/// This option only affects the Equals methods
/// and has no effect on the ApproxEquals methods.
///
/// Note: This option is only considered when \ref arrow::EqualOptions::use_schema is
/// set to true.
bool use_metadata() const { return use_metadata_; }

/// Return a new EqualOptions object with the "use_metadata" property changed.
EqualOptions use_metadata(bool v) const {
auto res = EqualOptions(*this);
res.use_metadata_ = v;
return res;
}

/// The ostream to which a diff will be formatted if arrays disagree.
/// If this is null (the default) no diff will be formatted.
std::ostream* diff_sink() const { return diff_sink_; }
Expand All @@ -103,6 +135,8 @@ class EqualOptions {
bool nans_equal_ = false;
bool signed_zeros_equal_ = true;
bool use_atol_ = false;
bool use_schema_ = true;
bool use_metadata_ = false;

std::ostream* diff_sink_ = NULLPTR;
};
Expand Down
30 changes: 7 additions & 23 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "arrow/array/statistics.h"
#include "arrow/array/validate.h"
#include "arrow/c/abi.h"
#include "arrow/compare.h"
#include "arrow/pretty_print.h"
#include "arrow/status.h"
#include "arrow/table.h"
Expand Down Expand Up @@ -349,44 +350,27 @@ bool CanIgnoreNaNInEquality(const RecordBatch& batch, const EqualOptions& opts)

bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata,
const EqualOptions& opts) const {
if (this == &other) {
if (CanIgnoreNaNInEquality(*this, opts)) {
return true;
}
} else {
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
return false;
} else if (!schema_->Equals(*other.schema(), check_metadata)) {
return false;
} else if (device_type() != other.device_type()) {
return false;
}
}

for (int i = 0; i < num_columns(); ++i) {
if (!column(i)->Equals(other.column(i), opts)) {
return false;
}
}

return true;
return Equals(other, opts.use_metadata(check_metadata));
}

bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opts) const {
bool RecordBatch::Equals(const RecordBatch& other, const EqualOptions& opts) const {
if (this == &other) {
if (CanIgnoreNaNInEquality(*this, opts)) {
return true;
}
} else {
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
return false;
} else if (opts.use_schema() &&
!schema_->Equals(*other.schema(), opts.use_metadata())) {
return false;
} else if (device_type() != other.device_type()) {
return false;
}
}

for (int i = 0; i < num_columns(); ++i) {
if (!column(i)->ApproxEquals(other.column(i), opts)) {
if (!column(i)->Equals(other.column(i), opts)) {
return false;
}
}
Expand Down
16 changes: 13 additions & 3 deletions cpp/src/arrow/record_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,32 @@ class ARROW_EXPORT RecordBatch {
static Result<std::shared_ptr<RecordBatch>> FromStructArray(
const std::shared_ptr<Array>& array, MemoryPool* pool = default_memory_pool());

/// \brief Determine if two record batches are exactly equal
/// \brief Determine if two record batches are equal
///
/// \param[in] other the RecordBatch to compare with
/// \param[in] check_metadata if true, check that Schema metadata is the same
/// \param[in] check_metadata if true, the schema metadata will be compared,
/// regardless of the value set in \ref EqualOptions::use_metadata_
/// \param[in] opts the options for equality comparisons
/// \return true if batches are equal
bool Equals(const RecordBatch& other, bool check_metadata = false,
const EqualOptions& opts = EqualOptions::Defaults()) const;

/// \brief Determine if two record batches are equal
///
/// \param[in] other the RecordBatch to compare with
/// \param[in] opts the options for equality comparisons
/// \return true if batches are equal
bool Equals(const RecordBatch& other, const EqualOptions& opts) const;

/// \brief Determine if two record batches are approximately equal
///
/// \param[in] other the RecordBatch to compare with
/// \param[in] opts the options for equality comparisons
/// \return true if batches are approximately equal
bool ApproxEquals(const RecordBatch& other,
const EqualOptions& opts = EqualOptions::Defaults()) const;
const EqualOptions& opts = EqualOptions::Defaults()) const {
return Equals(other, opts.use_schema(false).use_atol(true));
}

/// \return the record batch's schema
const std::shared_ptr<Schema>& schema() const { return schema_; }
Expand Down
139 changes: 105 additions & 34 deletions cpp/src/arrow/record_batch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "arrow/array/util.h"
#include "arrow/c/abi.h"
#include "arrow/chunked_array.h"
#include "arrow/compare.h"
#include "arrow/config.h"
#include "arrow/status.h"
#include "arrow/table.h"
Expand All @@ -64,44 +65,100 @@ class TestRecordBatch : public ::testing::Test {};
TEST_F(TestRecordBatch, Equals) {
const int length = 10;

auto f0 = field("f0", int32());
auto f1 = field("f1", uint8());
auto f2 = field("f2", int16());

auto schema = ::arrow::schema({f0, f1, f2});
auto schema_same = ::arrow::schema({f0, f1, f2});
auto schema_fewer_fields = ::arrow::schema({f0, f1});

random::RandomArrayGenerator gen(42);

auto a_f0 = gen.ArrayOf(int32(), length);
auto a_f1 = gen.ArrayOf(uint8(), length);
auto a_f2 = gen.ArrayOf(int16(), length);
auto a_f0_half = a_f0->Slice(0, length / 2);
auto a_f1_half = a_f1->Slice(0, length / 2);
auto a_f0_different = gen.ArrayOf(int32(), length);
auto a_f1_different = gen.ArrayOf(uint8(), length);

auto b = RecordBatch::Make(schema, length, {a_f0, a_f1, a_f2});
auto b_same = RecordBatch::Make(schema_same, length, {a_f0, a_f1, a_f2});
auto b_fewer_fields = RecordBatch::Make(schema_fewer_fields, length, {a_f0, a_f1});
auto b_fewer_fields_half =
RecordBatch::Make(schema_fewer_fields, length / 2, {a_f0_half, a_f1_half});
auto b_fewer_fields_different =
RecordBatch::Make(schema_fewer_fields, length, {a_f0_different, a_f1_different});

// Same Values
ASSERT_TRUE(b->Equals(*b_same));

// Different number of columns
ASSERT_FALSE(b->Equals(*b_fewer_fields));

// Different number of rows
ASSERT_FALSE(b_fewer_fields->Equals(*b_fewer_fields_half));

// Different values
ASSERT_FALSE(b_fewer_fields->Equals(*b_fewer_fields_different));
}

class TestRecordBatchEqualOptions : public TestRecordBatch {};

TEST_F(TestRecordBatchEqualOptions, MetadataAndSchema) {
int length = 10;

auto f0 = field("f0", int32());
auto f1 = field("f1", uint8());
auto f2 = field("f2", int16());
auto f2b = field("f2b", int16());

auto metadata = key_value_metadata({"foo"}, {"bar"});

std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
auto schema = ::arrow::schema({f0, f1, f2});
auto schema2 = ::arrow::schema({f0, f1});
auto schema3 = ::arrow::schema({f0, f1, f2}, metadata);
auto schema4 = ::arrow::schema({f0, f1, f2b});
auto schema_with_metadata = ::arrow::schema({f0, f1, f2}, metadata);
auto schema_renamed_field = ::arrow::schema({f0, f1, f2b});

random::RandomArrayGenerator gen(42);

auto a0 = gen.ArrayOf(int32(), length);
auto a1 = gen.ArrayOf(uint8(), length);
auto a2 = gen.ArrayOf(int16(), length);
auto a_f0 = gen.ArrayOf(int32(), length);
auto a_f1 = gen.ArrayOf(uint8(), length);
auto a_f2 = gen.ArrayOf(int16(), length);
auto a_f2b = a_f2;

auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2});
auto b3 = RecordBatch::Make(schema2, length, {a0, a1});
auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1});
auto b5 = RecordBatch::Make(schema4, length, {a0, a1, a2});
// All RecordBatches have the same values but different schemas.
auto b = RecordBatch::Make(schema, length, {a_f0, a_f1, a_f2});
auto b_with_metadata =
RecordBatch::Make(schema_with_metadata, length, {a_f0, a_f1, a_f2});
auto b_renamed_field =
RecordBatch::Make(schema_renamed_field, length, {a_f0, a_f1, a_f2b});

ASSERT_TRUE(b1->Equals(*b1));
ASSERT_FALSE(b1->Equals(*b3));
ASSERT_FALSE(b1->Equals(*b4));
auto options = EqualOptions::Defaults();

// Same values and types, but different field names
ASSERT_FALSE(b1->Equals(*b5));
ASSERT_FALSE(b->Equals(*b_renamed_field));
ASSERT_TRUE(b->Equals(*b_renamed_field, options.use_schema(false)));
ASSERT_TRUE(b->ApproxEquals(*b_renamed_field));
ASSERT_TRUE(b->ApproxEquals(*b_renamed_field, options.use_schema(true)));

// Different metadata
ASSERT_TRUE(b1->Equals(*b2));
ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true));
ASSERT_TRUE(b->Equals(*b_with_metadata));
ASSERT_TRUE(b->Equals(*b_with_metadata, options));
ASSERT_FALSE(b->Equals(*b_with_metadata,
/*check_metadata=*/true));
ASSERT_FALSE(b->Equals(*b_with_metadata,
/*check_metadata=*/true, options.use_schema(true)));
ASSERT_TRUE(b->Equals(*b_with_metadata,
/*check_metadata=*/true, options.use_schema(false)));
ASSERT_TRUE(b->Equals(*b_with_metadata, options.use_schema(true).use_metadata(false)));
ASSERT_FALSE(b->Equals(*b_with_metadata, options.use_schema(true).use_metadata(true)));
ASSERT_TRUE(b->Equals(*b_with_metadata, options.use_schema(false).use_metadata(true)));
ASSERT_TRUE(
b->ApproxEquals(*b_with_metadata, options.use_schema(true).use_metadata(true)));
}

TEST_F(TestRecordBatch, EqualOptions) {
TEST_F(TestRecordBatchEqualOptions, NaN) {
int length = 2;
auto f = field("f", float64());

Expand All @@ -114,13 +171,27 @@ TEST_F(TestRecordBatch, EqualOptions) {
auto b1 = RecordBatch::Make(schema, length, {array1});
auto b2 = RecordBatch::Make(schema, length, {array2});

EXPECT_FALSE(b1->Equals(*b2, /*check_metadata=*/false,
EqualOptions::Defaults().nans_equal(false)));
EXPECT_TRUE(b1->Equals(*b2, /*check_metadata=*/false,
EqualOptions::Defaults().nans_equal(true)));
EXPECT_FALSE(b1->Equals(*b2, EqualOptions::Defaults().nans_equal(false)));
EXPECT_TRUE(b1->Equals(*b2, EqualOptions::Defaults().nans_equal(true)));
}

TEST_F(TestRecordBatchEqualOptions, SignedZero) {
int length = 2;
auto f = field("f", float64());

auto schema = ::arrow::schema({f});

std::shared_ptr<Array> array1, array2;
ArrayFromVector<DoubleType>(float64(), {true, true}, {0.5, +0.0}, &array1);
ArrayFromVector<DoubleType>(float64(), {true, true}, {0.5, -0.0}, &array2);
auto b1 = RecordBatch::Make(schema, length, {array1});
auto b2 = RecordBatch::Make(schema, length, {array2});

ASSERT_FALSE(b1->Equals(*b2, EqualOptions::Defaults().signed_zeros_equal(false)));
ASSERT_TRUE(b1->Equals(*b2, EqualOptions::Defaults().signed_zeros_equal(true)));
}

TEST_F(TestRecordBatch, ApproxEqualOptions) {
TEST_F(TestRecordBatchEqualOptions, Approx) {
int length = 2;
auto f = field("f", float64());

Expand All @@ -137,8 +208,8 @@ TEST_F(TestRecordBatch, ApproxEqualOptions) {
EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true)));

auto options = EqualOptions::Defaults().nans_equal(true).atol(0.1);
EXPECT_FALSE(b1->Equals(*b2, false, options));
EXPECT_TRUE(b1->Equals(*b2, false, options.use_atol(true)));
EXPECT_FALSE(b1->Equals(*b2, options));
EXPECT_TRUE(b1->Equals(*b2, options.use_atol(true)));
EXPECT_TRUE(b1->ApproxEquals(*b2, options));
}

Expand All @@ -158,8 +229,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NonFloatType) {

auto options = EqualOptions::Defaults();

ASSERT_TRUE(b0->Equals(*b1, true, options));
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
ASSERT_TRUE(b0->Equals(*b1, options));
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));

ASSERT_TRUE(b0->ApproxEquals(*b1, options));
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
Expand All @@ -180,8 +251,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithoutFloatType) {

auto options = EqualOptions::Defaults();

ASSERT_TRUE(b0->Equals(*b1, true, options));
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
ASSERT_TRUE(b0->Equals(*b1, options));
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));

ASSERT_TRUE(b0->ApproxEquals(*b1, options));
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
Expand All @@ -201,8 +272,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, FloatType) {

auto options = EqualOptions::Defaults();

ASSERT_FALSE(b0->Equals(*b1, true, options));
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
ASSERT_FALSE(b0->Equals(*b1, options));
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));

ASSERT_FALSE(b0->ApproxEquals(*b1, options));
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
Expand All @@ -223,8 +294,8 @@ TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithFloatType) {

auto options = EqualOptions::Defaults();

ASSERT_FALSE(b0->Equals(*b1, true, options));
ASSERT_TRUE(b0->Equals(*b1, true, options.nans_equal(true)));
ASSERT_FALSE(b0->Equals(*b1, options));
ASSERT_TRUE(b0->Equals(*b1, options.nans_equal(true)));

ASSERT_FALSE(b0->ApproxEquals(*b1, options));
ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
Expand Down
6 changes: 6 additions & 0 deletions docs/source/cpp/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ Iterators
.. doxygenclass:: arrow::VectorIterator
:members:

Comparison
==========

.. doxygenclass:: arrow::EqualOptions
:members:

Compression
===========

Expand Down
Loading