Skip to content

Commit 54c6b2b

Browse files
committed
[C++] Optimize StructArray diffing with field by field comparison
1 parent 97c656b commit 54c6b2b

File tree

2 files changed

+136
-6
lines changed

2 files changed

+136
-6
lines changed

cpp/src/arrow/array/diff.cc

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,6 @@ struct UnitSlice {
9393
bool operator!=(const UnitSlice& other) const { return !(*this == other); }
9494
};
9595

96-
// FIXME(bkietz) this is inefficient;
97-
// StructArray's fields can be diffed independently then merged
98-
UnitSlice GetView(const StructArray& array, int64_t index) {
99-
return UnitSlice{&array, index};
100-
}
101-
10296
// Produce the view of one union slot: a UnitSlice built from the array's
// address and the slot index.
UnitSlice GetView(const UnionArray& array, int64_t index) {
  const auto* slot_owner = &array;
  return UnitSlice{slot_owner, index};
}
@@ -164,6 +158,45 @@ struct DefaultValueComparator : public ValueComparator {
164158
}
165159
};
166160

161+
class StructValueComparator : public ValueComparator {
162+
private:
163+
const StructArray& base_;
164+
const StructArray& target_;
165+
std::vector<std::unique_ptr<ValueComparator>> field_comparators_;
166+
167+
public:
168+
StructValueComparator(const StructArray& base, const StructArray& target,
169+
std::vector<std::unique_ptr<ValueComparator>>&& field_comparators)
170+
: base_(base), target_(target), field_comparators_(std::move(field_comparators)) {
171+
DCHECK_EQ(*base_.type(), *target_.type());
172+
DCHECK_EQ(base_.num_fields(), static_cast<int>(field_comparators_.size()));
173+
}
174+
175+
~StructValueComparator() override = default;
176+
177+
bool Equals(int64_t base_index, int64_t target_index) override {
178+
const bool base_valid = base_.IsValid(base_index);
179+
const bool target_valid = target_.IsValid(target_index);
180+
181+
if (base_valid != target_valid) {
182+
return false;
183+
}
184+
185+
if (!base_valid) {
186+
return true; // Both null
187+
}
188+
189+
// Compare each field independently with early termination
190+
for (const auto& field_comparator : field_comparators_) {
191+
if (!field_comparator->Equals(base_index, target_index)) {
192+
return false;
193+
}
194+
}
195+
196+
return true;
197+
}
198+
};
199+
167200
template <typename RunEndCType>
168201
class REEValueComparator : public ValueComparator {
169202
private:
@@ -308,6 +341,26 @@ class ValueComparatorFactory {
308341
return Status::NotImplemented("dictionary type");
309342
}
310343

344+
/// \brief Install a comparator that compares struct slots field by field.
///
/// One comparator per child array is obtained through Create(). When Create()
/// reports NotImplemented for a child type (this factory does so for
/// dictionary types, for example), that child is compared slot-by-slot with
/// Array::RangeEquals instead, so a single unsupported child type does not
/// make the whole struct array impossible to diff.
///
/// \param struct_type type of both arrays; supplies the field count and field types
/// \param base array on the base side; cast to StructArray
/// \param target array on the target side; cast to StructArray
/// \return OK once comparator_ has been set, otherwise the first error other
///         than NotImplemented produced while creating a field comparator
Status Visit(const StructType& struct_type, const Array& base, const Array& target) {
  // Fallback for children without a dedicated comparator: compares the
  // single-slot ranges [base_index, base_index + 1) and
  // [target_index, target_index + 1).
  class RangeEqualsComparator : public ValueComparator {
   public:
    RangeEqualsComparator(const Array& base, const Array& target)
        : base_(base), target_(target) {}

    bool Equals(int64_t base_index, int64_t target_index) override {
      return base_.RangeEquals(base_index, base_index + 1, target_index, target_);
    }

   private:
    const Array& base_;
    const Array& target_;
  };

  const auto& base_struct = checked_cast<const StructArray&>(base);
  const auto& target_struct = checked_cast<const StructArray&>(target);

  const int num_fields = struct_type.num_fields();
  std::vector<std::unique_ptr<ValueComparator>> field_comparators;
  field_comparators.reserve(num_fields);

  for (int i = 0; i < num_fields; ++i) {
    const Array& base_field = *base_struct.field(i);
    const Array& target_field = *target_struct.field(i);

    auto maybe_comparator =
        Create(*struct_type.field(i)->type(), base_field, target_field);
    if (maybe_comparator.ok()) {
      field_comparators.push_back(maybe_comparator.MoveValueUnsafe());
    } else if (maybe_comparator.status().IsNotImplemented()) {
      // No specialized comparator for this child type: keep the struct
      // diffable by comparing this child's slots generically.
      field_comparators.push_back(
          std::make_unique<RangeEqualsComparator>(base_field, target_field));
    } else {
      return maybe_comparator.status();
    }
  }

  comparator_ = std::make_unique<StructValueComparator>(base_struct, target_struct,
                                                        std::move(field_comparators));
  return Status::OK();
}
363+
311364
Status Visit(const RunEndEncodedType& ree_type, const Array& base,
312365
const Array& target) {
313366
const auto& base_ree = checked_cast<const RunEndEncodedArray&>(base);

cpp/src/arrow/array/diff_test.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,83 @@ TEST_F(DiffTest, CompareRandomStruct) {
816816
}
817817
}
818818

819+
TEST_F(DiffTest, StructFieldComparison) {
  // Diff single-row struct arrays that differ in exactly one field (or in none)
  // and check the resulting edit script.
  auto type = struct_(
      {field("first", int32()), field("second", utf8()), field("third", int64())});

  // Every case diffs against the same one-row base array.
  const char* reference_row = R"([{"first": 1, "second": "a", "third": 100}])";
  auto diff_reference_against = [&](const char* target_json) {
    base_ = ArrayFromJSON(type, reference_row);
    target_ = ArrayFromJSON(type, target_json);
    DoDiff();
  };

  // mismatch in "first"
  diff_reference_against(R"([{"first": 2, "second": "a", "third": 100}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // mismatch in "second"
  diff_reference_against(R"([{"first": 1, "second": "b", "third": 100}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // mismatch in "third"
  diff_reference_against(R"([{"first": 1, "second": "a", "third": 200}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // identical rows: no edits, one run of length 1
  diff_reference_against(R"([{"first": 1, "second": "a", "third": 100}])");
  AssertInsertIs("[false]");
  AssertRunLengthIs("[1]");
}
852+
853+
TEST_F(DiffTest, NestedStructComparison) {
  // Diff single-row arrays of a struct that contains another struct, varying
  // one field at a time at either nesting level.
  auto inner_type = struct_({field("x", int32()), field("y", int32())});
  auto outer_type =
      struct_({field("id", int32()), field("inner", inner_type), field("name", utf8())});

  // Every case diffs against the same one-row base array.
  const char* reference_row =
      R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])";
  auto diff_reference_against = [&](const char* target_json) {
    base_ = ArrayFromJSON(outer_type, reference_row);
    target_ = ArrayFromJSON(outer_type, target_json);
    DoDiff();
  };

  // mismatch in the outer "id" field
  diff_reference_against(R"([{"id": 2, "inner": {"x": 10, "y": 20}, "name": "test"}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // mismatch in the nested "x" field
  diff_reference_against(R"([{"id": 1, "inner": {"x": 99, "y": 20}, "name": "test"}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // mismatch in the nested "y" field
  diff_reference_against(R"([{"id": 1, "inner": {"x": 10, "y": 99}, "name": "test"}])");
  AssertInsertIs("[false, false, true]");
  AssertRunLengthIs("[0, 0, 0]");

  // identical rows, nested struct included: no edits, one run of length 1
  diff_reference_against(R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
  AssertInsertIs("[false]");
  AssertRunLengthIs("[1]");
}
895+
819896
TEST_F(DiffTest, CompareHalfFloat) {
820897
auto first = ArrayFromJSON(float16(), "[1.1, 2.0, 2.5, 3.3]");
821898
auto second = ArrayFromJSON(float16(), "[1.1, 4.0, 3.5, 3.3]");

0 commit comments

Comments
 (0)