|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | #include <gtest/gtest.h> |
| 19 | +#include <memory> |
| 20 | +#include <vector> |
19 | 21 |
|
20 | 22 | #include "arrow/acero/exec_plan.h" |
21 | 23 | #include "arrow/acero/map_node.h" |
22 | 24 | #include "arrow/acero/options.h" |
23 | 25 | #include "arrow/acero/test_nodes.h" |
24 | 26 | #include "arrow/array/builder_base.h" |
| 27 | +#include "arrow/array/builder_primitive.h" |
25 | 28 | #include "arrow/array/concatenate.h" |
26 | 29 | #include "arrow/compute/ordering.h" |
| 30 | +#include "arrow/dataset/dataset.h" |
| 31 | +#include "arrow/dataset/scanner.h" |
| 32 | +#include "arrow/record_batch.h" |
27 | 33 | #include "arrow/result.h" |
28 | 34 | #include "arrow/scalar.h" |
| 35 | +#include "arrow/status.h" |
29 | 36 | #include "arrow/table.h" |
30 | 37 | #include "arrow/testing/generator.h" |
31 | 38 | #include "arrow/testing/gtest_util.h" |
@@ -83,4 +90,107 @@ TEST(SortedMergeNode, Basic) { |
83 | 90 | AssertArraysEqual(*expected_ts, *output_ts); |
84 | 91 | } |
85 | 92 |
|
| 93 | +TEST(SortedMergeNode, TestSortedMergeTwoInputsWithBool) { |
| 94 | + const int64_t row_count = (16 << 10); // 16k rows per input |
| 95 | + |
| 96 | + // Create schema with int column A and bool column B |
| 97 | + auto test_schema = arrow::schema( |
| 98 | + {arrow::field("col_a", arrow::int32()), arrow::field("col_b", arrow::boolean())}); |
| 99 | + |
| 100 | + // Helper lambda to create table with specific pattern |
| 101 | + auto create_test_scanner = [&](int64_t cnt, int offset) -> arrow::Result<Declaration> { |
| 102 | + // Create column A (int) - values from offset to offset+cnt-1 |
| 103 | + arrow::Int32Builder col_a_builder; |
| 104 | + std::vector<int32_t> col_a_values; |
| 105 | + col_a_values.reserve(cnt); |
| 106 | + for (int64_t i = 0; i < cnt; ++i) { |
| 107 | + col_a_values.push_back(static_cast<int32_t>(offset + i)); |
| 108 | + } |
| 109 | + ARROW_RETURN_NOT_OK(col_a_builder.AppendValues(col_a_values)); |
| 110 | + std::shared_ptr<arrow::Array> col_a_arr; |
| 111 | + ARROW_RETURN_NOT_OK(col_a_builder.Finish(&col_a_arr)); |
| 112 | + |
| 113 | + // Create column B (bool) - pattern: true if col_a % 5 == 0, false otherwise |
| 114 | + arrow::BooleanBuilder col_b_builder; |
| 115 | + for (int64_t i = 0; i < cnt; ++i) { |
| 116 | + int32_t a_value = offset + i; |
| 117 | + bool b_value = (a_value % 5 == 0); |
| 118 | + ARROW_RETURN_NOT_OK(col_b_builder.Append(b_value)); |
| 119 | + } |
| 120 | + std::shared_ptr<arrow::Array> col_b_arr; |
| 121 | + ARROW_RETURN_NOT_OK(col_b_builder.Finish(&col_b_arr)); |
| 122 | + |
| 123 | + auto table = arrow::Table::Make(test_schema, {col_a_arr, col_b_arr}); |
| 124 | + auto table_source = |
| 125 | + Declaration("table_source", TableSourceNodeOptions(table, row_count / 16)); |
| 126 | + return table_source; |
| 127 | + }; |
| 128 | + |
| 129 | + ASSERT_OK_AND_ASSIGN(auto source1, create_test_scanner(row_count, 0)); |
| 130 | + ASSERT_OK_AND_ASSIGN(auto source2, create_test_scanner(row_count, 8192)); |
| 131 | + |
| 132 | + // Create sorted merge by column A |
| 133 | + auto ops = OrderByNodeOptions(compute::Ordering({compute::SortKey("col_a")})); |
| 134 | + Declaration sorted_merge{"sorted_merge", {source1, source2}, ops}; |
| 135 | + |
| 136 | + // Execute plan and collect result |
| 137 | + ASSERT_OK_AND_ASSIGN(auto result_table, |
| 138 | + arrow::acero::DeclarationToTable(sorted_merge, false)); |
| 139 | + |
| 140 | + ASSERT_TRUE(result_table != nullptr); |
| 141 | + |
| 142 | + // Verify results |
| 143 | + auto col_a = result_table->GetColumnByName("col_a"); |
| 144 | + auto col_b = result_table->GetColumnByName("col_b"); |
| 145 | + ASSERT_TRUE(col_a != nullptr); |
| 146 | + ASSERT_TRUE(col_b != nullptr); |
| 147 | + |
| 148 | + // Verify sorting and bool values |
| 149 | + int32_t last_a_value = std::numeric_limits<int32_t>::min(); |
| 150 | + int64_t total_rows_checked = 0; |
| 151 | + int64_t true_count = 0; |
| 152 | + int64_t false_count = 0; |
| 153 | + |
| 154 | + for (int i = 0; i < col_a->num_chunks(); i++) { |
| 155 | + auto a_chunk = std::static_pointer_cast<arrow::Int32Array>(col_a->chunk(i)); |
| 156 | + auto b_chunk = std::static_pointer_cast<arrow::BooleanArray>(col_b->chunk(i)); |
| 157 | + |
| 158 | + ASSERT_EQ(a_chunk->length(), b_chunk->length()) |
| 159 | + << "Column A and B must have same length in chunk " << i; |
| 160 | + |
| 161 | + for (int64_t j = 0; j < a_chunk->length(); j++) { |
| 162 | + ASSERT_FALSE(a_chunk->IsNull(j)) << "Column A should not have null values"; |
| 163 | + ASSERT_FALSE(b_chunk->IsNull(j)) << "Column B should not have null values"; |
| 164 | + |
| 165 | + int32_t a_value = a_chunk->Value(j); |
| 166 | + bool b_value = b_chunk->Value(j); |
| 167 | + |
| 168 | + // Verify sorting by column A |
| 169 | + ASSERT_GE(a_value, last_a_value) |
| 170 | + << "Values not sorted at chunk " << i << ", row " << j |
| 171 | + << ": current=" << a_value << ", last=" << last_a_value; |
| 172 | + last_a_value = a_value; |
| 173 | + |
| 174 | + // Verify bool value correctness: should be true if a_value % 3 == 0 |
| 175 | + bool expected_b_value = (a_value % 5 == 0); |
| 176 | + ASSERT_EQ(b_value, expected_b_value) |
| 177 | + << "Bool value incorrect at chunk " << i << ", row " << j |
| 178 | + << ": col_a=" << a_value << ", col_b=" << b_value |
| 179 | + << ", expected=" << expected_b_value; |
| 180 | + |
| 181 | + if (b_value) { |
| 182 | + true_count++; |
| 183 | + } else { |
| 184 | + false_count++; |
| 185 | + } |
| 186 | + total_rows_checked++; |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + ASSERT_EQ(last_a_value, 24575); |
| 191 | + |
| 192 | + ASSERT_EQ(total_rows_checked, row_count * 2) |
| 193 | + << "Expected " << row_count << " unique rows after merge"; |
| 194 | +} |
| 195 | + |
86 | 196 | } // namespace arrow::acero |
0 commit comments