Skip to content

Commit 16a2aca

Browse files
committed
[C++] Test filter operations with random null probabilities
1 parent 35717a7 commit 16a2aca

File tree

1 file changed

+64
-22
lines changed

1 file changed

+64
-22
lines changed

cpp/src/arrow/compute/kernels/vector_selection_test.cc

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -484,27 +484,66 @@ Comparator<CType>* GetComparator(CompareOperator op) {
484484
}
485485

486486
template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
487-
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, Fn&& fn) {
487+
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& array, Fn&& fn) {
488+
using ArrayType = typename TypeTraits<T>::ArrayType;
489+
auto typed_array = checked_pointer_cast<ArrayType>(array);
490+
488491
std::vector<CType> filtered;
489-
filtered.reserve(length);
490-
std::copy_if(data, data + length, std::back_inserter(filtered), std::forward<Fn>(fn));
492+
filtered.reserve(array->length());
493+
494+
for (int64_t i = 0; i < array->length(); ++i) {
495+
if (array->IsNull(i)) {
496+
// Nulls are filtered out (comparison with null is false)
497+
continue;
498+
}
499+
CType value = typed_array->Value(i);
500+
if (fn(value)) {
501+
filtered.push_back(value);
502+
}
503+
}
504+
491505
std::shared_ptr<Array> filtered_array;
492506
ArrayFromVector<T, CType>(filtered, &filtered_array);
493507
return filtered_array;
494508
}
495509

496-
template <typename T, typename CType = typename TypeTraits<T>::CType>
497-
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, CType val,
510+
template <typename T, typename U,
511+
typename = std::enable_if_t<std::is_same_v<U, typename TypeTraits<T>::CType>>>
512+
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& array, U val,
498513
CompareOperator op) {
514+
using CType = typename TypeTraits<T>::CType;
499515
auto cmp = GetComparator<CType>(op);
500-
return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, val); });
516+
return CompareAndFilter<T>(array, [&](CType e) { return cmp(e, val); });
501517
}
502518

503-
template <typename T, typename CType = typename TypeTraits<T>::CType>
504-
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length,
505-
const CType* other, CompareOperator op) {
519+
template <typename T>
520+
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& lhs,
521+
const std::shared_ptr<Array>& rhs,
522+
CompareOperator op) {
523+
using ArrayType = typename TypeTraits<T>::ArrayType;
524+
using CType = typename TypeTraits<T>::CType;
525+
auto lhs_typed = checked_pointer_cast<ArrayType>(lhs);
526+
auto rhs_typed = checked_pointer_cast<ArrayType>(rhs);
506527
auto cmp = GetComparator<CType>(op);
507-
return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, *other++); });
528+
529+
std::vector<CType> filtered;
530+
filtered.reserve(lhs->length());
531+
532+
for (int64_t i = 0; i < lhs->length(); ++i) {
533+
// Skip if either element is null
534+
if (lhs->IsNull(i) || rhs->IsNull(i)) {
535+
continue;
536+
}
537+
CType lhs_value = lhs_typed->Value(i);
538+
CType rhs_value = rhs_typed->Value(i);
539+
if (cmp(lhs_value, rhs_value)) {
540+
filtered.push_back(lhs_value);
541+
}
542+
}
543+
544+
std::shared_ptr<Array> filtered_array;
545+
ArrayFromVector<T, CType>(filtered, &filtered_array);
546+
return filtered_array;
508547
}
509548

510549
TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
@@ -515,9 +554,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
515554
auto rand = random::RandomArrayGenerator(kRandomSeed);
516555
for (size_t i = 3; i < 10; i++) {
517556
const int64_t length = static_cast<int64_t>(1ULL << i);
518-
// TODO(bkietz) rewrite with some nulls
519-
auto array =
520-
checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0, 100, 0));
557+
// Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
558+
double null_probability = static_cast<double>(i - 3) / i;
559+
auto array = checked_pointer_cast<ArrayType>(
560+
rand.Numeric<TypeParam>(length, 0, 100, null_probability));
521561
CType c_fifty = 50;
522562
auto fifty = std::make_shared<ScalarType>(c_fifty);
523563
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
@@ -527,8 +567,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
527567
ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
528568
auto filtered_array = filtered.make_array();
529569
ValidateOutput(*filtered_array);
530-
auto expected =
531-
CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
570+
auto expected = CompareAndFilter<TypeParam>(array, c_fifty, op);
532571
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
533572
}
534573
}
@@ -540,18 +579,20 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
540579
auto rand = random::RandomArrayGenerator(kRandomSeed);
541580
for (size_t i = 3; i < 10; i++) {
542581
const int64_t length = static_cast<int64_t>(1ULL << i);
582+
// Use deterministic null probabilities with different values for lhs and rhs
583+
double null_probability_lhs = static_cast<double>(i - 3) / i;
584+
double null_probability_rhs = static_cast<double>(i) / (i + 7);
543585
auto lhs = checked_pointer_cast<ArrayType>(
544-
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
586+
rand.Numeric<TypeParam>(length, 0, 100, null_probability_lhs));
545587
auto rhs = checked_pointer_cast<ArrayType>(
546-
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
588+
rand.Numeric<TypeParam>(length, 0, 100, null_probability_rhs));
547589
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
548590
ASSERT_OK_AND_ASSIGN(Datum selection,
549591
CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs}));
550592
ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(lhs, selection));
551593
auto filtered_array = filtered.make_array();
552594
ValidateOutput(*filtered_array);
553-
auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
554-
rhs->raw_values(), op);
595+
auto expected = CompareAndFilter<TypeParam>(lhs, rhs, op);
555596
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
556597
}
557598
}
@@ -565,8 +606,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
565606
auto rand = random::RandomArrayGenerator(kRandomSeed);
566607
for (size_t i = 3; i < 10; i++) {
567608
const int64_t length = static_cast<int64_t>(1ULL << i);
609+
// Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
610+
double null_probability = static_cast<double>(i - 3) / i;
568611
auto array = checked_pointer_cast<ArrayType>(
569-
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
612+
rand.Numeric<TypeParam>(length, 0, 100, null_probability));
570613
CType c_fifty = 50, c_hundred = 100;
571614
auto fifty = std::make_shared<ScalarType>(c_fifty);
572615
auto hundred = std::make_shared<ScalarType>(c_hundred);
@@ -579,8 +622,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
579622
auto filtered_array = filtered.make_array();
580623
ValidateOutput(*filtered_array);
581624
auto expected = CompareAndFilter<TypeParam>(
582-
array->raw_values(), array->length(),
583-
[&](CType e) { return (e > c_fifty) && (e < c_hundred); });
625+
array, [&](CType e) { return (e > c_fifty) && (e < c_hundred); });
584626
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
585627
}
586628
}

0 commit comments

Comments
 (0)