Skip to content

Commit 1567be0

Browse files
mcshawn10pitrou
andauthored
GH-26648: [C++] Optimize union equality comparison (#45384)
### Rationale for this change #26648 proposes an optimization in checking sparse array equality by detecting contiguous runs, this PR implements that change ### What changes are included in this PR? previously, sparse array comparison was checked one by one, in this change, contiguous runs are detected and compared by checking equality of current and previous child_ids ### Are these changes tested? already covered by existing unit tests ### Are there any user-facing changes? no * GitHub Issue: #26648 Lead-authored-by: shawn <[email protected]> Co-authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 31747f0 commit 1567be0

File tree

2 files changed

+66
-8
lines changed

2 files changed

+66
-8
lines changed

cpp/src/arrow/array/array_union_test.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,36 @@ TEST(TestSparseUnionArray, Validate) {
166166
ASSERT_RAISES(Invalid, arr->ValidateFull());
167167
}
168168

169+
TEST(TestSparseUnionArray, Comparison) {
170+
auto ints1 = ArrayFromJSON(int32(), "[1, 2, 3, 4, 5, 6]");
171+
auto ints2 = ArrayFromJSON(int32(), "[1, 2, -3, 4, -5, 6]");
172+
auto strs1 = ArrayFromJSON(utf8(), R"(["a", "b", "c", "d", "e", "f"])");
173+
auto strs2 = ArrayFromJSON(utf8(), R"(["a", "*", "c", "d", "e", "*"])");
174+
std::vector<int8_t> type_codes{8, 42};
175+
176+
auto check_equality = [&](const std::string& type_ids_json1,
177+
const std::string& type_ids_json2, bool expected_equals) {
178+
auto type_ids1 = ArrayFromJSON(int8(), type_ids_json1);
179+
auto type_ids2 = ArrayFromJSON(int8(), type_ids_json2);
180+
ASSERT_OK_AND_ASSIGN(auto arr1,
181+
SparseUnionArray::Make(*type_ids1, {ints1, strs1}, type_codes));
182+
ASSERT_OK_AND_ASSIGN(auto arr2,
183+
SparseUnionArray::Make(*type_ids2, {ints2, strs2}, type_codes));
184+
ASSERT_EQ(arr1->Equals(arr2), expected_equals);
185+
ASSERT_EQ(arr2->Equals(arr1), expected_equals);
186+
};
187+
188+
// Same type ids
189+
check_equality("[8, 8, 42, 42, 42, 8]", "[8, 8, 42, 42, 42, 8]", true);
190+
check_equality("[8, 8, 42, 42, 42, 42]", "[8, 8, 42, 42, 42, 42]", false);
191+
check_equality("[8, 8, 8, 42, 42, 8]", "[8, 8, 8, 42, 42, 8]", false);
192+
check_equality("[8, 42, 42, 42, 42, 8]", "[8, 42, 42, 42, 42, 8]", false);
193+
194+
// Different type ids
195+
check_equality("[42, 8, 42, 42, 42, 8]", "[8, 8, 42, 42, 42, 8]", false);
196+
check_equality("[8, 8, 42, 42, 42, 8]", "[8, 8, 42, 42, 42, 42]", false);
197+
}
198+
169199
// -------------------------------------------------------------------------
170200
// Tests for MakeDense and MakeSparse
171201

cpp/src/arrow/compare.cc

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -381,21 +381,49 @@ class RangeDataEqualsImpl {
381381
const int8_t* right_codes = right_.GetValues<int8_t>(1);
382382

383383
// Unions don't have a null bitmap
384+
int64_t run_start = 0; // Start index of the current run
385+
384386
for (int64_t i = 0; i < range_length_; ++i) {
385-
const auto type_id = left_codes[left_start_idx_ + i];
386-
if (type_id != right_codes[right_start_idx_ + i]) {
387+
const auto current_type_id = left_codes[left_start_idx_ + i];
388+
389+
if (current_type_id != right_codes[right_start_idx_ + i]) {
387390
result_ = false;
388391
break;
389392
}
390-
const auto child_num = child_ids[type_id];
391-
// XXX can we instead detect runs of same-child union values?
393+
// Check if the current element breaks the run
394+
if (i > 0 && current_type_id != left_codes[left_start_idx_ + i - 1]) {
395+
// Compare the previous run
396+
const auto previous_child_num = child_ids[left_codes[left_start_idx_ + i - 1]];
397+
int64_t run_length = i - run_start;
398+
399+
RangeDataEqualsImpl impl(
400+
options_, floating_approximate_, *left_.child_data[previous_child_num],
401+
*right_.child_data[previous_child_num],
402+
left_start_idx_ + left_.offset + run_start,
403+
right_start_idx_ + right_.offset + run_start, run_length);
404+
405+
if (!impl.Compare()) {
406+
result_ = false;
407+
break;
408+
}
409+
410+
// Start a new run
411+
run_start = i;
412+
}
413+
}
414+
415+
// Handle the final run
416+
if (result_) {
417+
const auto final_child_num = child_ids[left_codes[left_start_idx_ + run_start]];
418+
int64_t final_run_length = range_length_ - run_start;
419+
392420
RangeDataEqualsImpl impl(
393-
options_, floating_approximate_, *left_.child_data[child_num],
394-
*right_.child_data[child_num], left_start_idx_ + left_.offset + i,
395-
right_start_idx_ + right_.offset + i, 1);
421+
options_, floating_approximate_, *left_.child_data[final_child_num],
422+
*right_.child_data[final_child_num], left_start_idx_ + left_.offset + run_start,
423+
right_start_idx_ + right_.offset + run_start, final_run_length);
424+
396425
if (!impl.Compare()) {
397426
result_ = false;
398-
break;
399427
}
400428
}
401429
return Status::OK();

0 commit comments

Comments
 (0)