@@ -484,27 +484,66 @@ Comparator<CType>* GetComparator(CompareOperator op) {
484484}
485485
486486template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
487- std::shared_ptr<Array> CompareAndFilter (const CType* data, int64_t length, Fn&& fn) {
487+ std::shared_ptr<Array> CompareAndFilter (const std::shared_ptr<Array>& array, Fn&& fn) {
488+ using ArrayType = typename TypeTraits<T>::ArrayType;
489+ auto typed_array = checked_pointer_cast<ArrayType>(array);
490+
488491 std::vector<CType> filtered;
489- filtered.reserve (length);
490- std::copy_if (data, data + length, std::back_inserter (filtered), std::forward<Fn>(fn));
492+ filtered.reserve (array->length ());
493+
494+ for (int64_t i = 0 ; i < array->length (); ++i) {
495+ if (array->IsNull (i)) {
496+ // Nulls are filtered out (comparison with null is false)
497+ continue ;
498+ }
499+ CType value = typed_array->Value (i);
500+ if (fn (value)) {
501+ filtered.push_back (value);
502+ }
503+ }
504+
491505 std::shared_ptr<Array> filtered_array;
492506 ArrayFromVector<T, CType>(filtered, &filtered_array);
493507 return filtered_array;
494508}
495509
496- template <typename T, typename CType = typename TypeTraits<T>::CType>
497- std::shared_ptr<Array> CompareAndFilter (const CType* data, int64_t length, CType val,
510+ template <typename T, typename U,
511+ typename = std::enable_if_t <std::is_same_v<U, typename TypeTraits<T>::CType>>>
512+ std::shared_ptr<Array> CompareAndFilter (const std::shared_ptr<Array>& array, U val,
498513 CompareOperator op) {
514+ using CType = typename TypeTraits<T>::CType;
499515 auto cmp = GetComparator<CType>(op);
500- return CompareAndFilter<T>(data, length , [&](CType e) { return cmp (e, val); });
516+ return CompareAndFilter<T>(array , [&](CType e) { return cmp (e, val); });
501517}
502518
503- template <typename T, typename CType = typename TypeTraits<T>::CType>
504- std::shared_ptr<Array> CompareAndFilter (const CType* data, int64_t length,
505- const CType* other, CompareOperator op) {
519+ template <typename T>
520+ std::shared_ptr<Array> CompareAndFilter (const std::shared_ptr<Array>& lhs,
521+ const std::shared_ptr<Array>& rhs,
522+ CompareOperator op) {
523+ using ArrayType = typename TypeTraits<T>::ArrayType;
524+ using CType = typename TypeTraits<T>::CType;
525+ auto lhs_typed = checked_pointer_cast<ArrayType>(lhs);
526+ auto rhs_typed = checked_pointer_cast<ArrayType>(rhs);
506527 auto cmp = GetComparator<CType>(op);
507- return CompareAndFilter<T>(data, length, [&](CType e) { return cmp (e, *other++); });
528+
529+ std::vector<CType> filtered;
530+ filtered.reserve (lhs->length ());
531+
532+ for (int64_t i = 0 ; i < lhs->length (); ++i) {
533+ // Skip if either element is null
534+ if (lhs->IsNull (i) || rhs->IsNull (i)) {
535+ continue ;
536+ }
537+ CType lhs_value = lhs_typed->Value (i);
538+ CType rhs_value = rhs_typed->Value (i);
539+ if (cmp (lhs_value, rhs_value)) {
540+ filtered.push_back (lhs_value);
541+ }
542+ }
543+
544+ std::shared_ptr<Array> filtered_array;
545+ ArrayFromVector<T, CType>(filtered, &filtered_array);
546+ return filtered_array;
508547}
509548
510549TYPED_TEST (TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
@@ -515,9 +554,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
515554 auto rand = random::RandomArrayGenerator (kRandomSeed );
516555 for (size_t i = 3 ; i < 10 ; i++) {
517556 const int64_t length = static_cast <int64_t >(1ULL << i);
518- // TODO(bkietz) rewrite with some nulls
519- auto array =
520- checked_pointer_cast<ArrayType>(rand.Numeric <TypeParam>(length, 0 , 100 , 0 ));
557+ // Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
558+ double null_probability = static_cast <double >(i - 3 ) / i;
559+ auto array = checked_pointer_cast<ArrayType>(
560+ rand.Numeric <TypeParam>(length, 0 , 100 , null_probability));
521561 CType c_fifty = 50 ;
522562 auto fifty = std::make_shared<ScalarType>(c_fifty);
523563 for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
@@ -527,8 +567,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
527567 ASSERT_OK_AND_ASSIGN (Datum filtered, Filter (array, selection));
528568 auto filtered_array = filtered.make_array ();
529569 ValidateOutput (*filtered_array);
530- auto expected =
531- CompareAndFilter<TypeParam>(array->raw_values (), array->length (), c_fifty, op);
570+ auto expected = CompareAndFilter<TypeParam>(array, c_fifty, op);
532571 ASSERT_ARRAYS_EQUAL (*filtered_array, *expected);
533572 }
534573 }
@@ -540,18 +579,20 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
540579 auto rand = random::RandomArrayGenerator (kRandomSeed );
541580 for (size_t i = 3 ; i < 10 ; i++) {
542581 const int64_t length = static_cast <int64_t >(1ULL << i);
582+ // Use deterministic null probabilities with different values for lhs and rhs
583+ double null_probability_lhs = static_cast <double >(i - 3 ) / i;
584+ double null_probability_rhs = static_cast <double >(i) / (i + 7 );
543585 auto lhs = checked_pointer_cast<ArrayType>(
544- rand.Numeric <TypeParam>(length, 0 , 100 , /* null_probability= */ 0.0 ));
586+ rand.Numeric <TypeParam>(length, 0 , 100 , null_probability_lhs ));
545587 auto rhs = checked_pointer_cast<ArrayType>(
546- rand.Numeric <TypeParam>(length, 0 , 100 , /* null_probability= */ 0.0 ));
588+ rand.Numeric <TypeParam>(length, 0 , 100 , null_probability_rhs ));
547589 for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
548590 ASSERT_OK_AND_ASSIGN (Datum selection,
549591 CallFunction (CompareOperatorToFunctionName (op), {lhs, rhs}));
550592 ASSERT_OK_AND_ASSIGN (Datum filtered, Filter (lhs, selection));
551593 auto filtered_array = filtered.make_array ();
552594 ValidateOutput (*filtered_array);
553- auto expected = CompareAndFilter<TypeParam>(lhs->raw_values (), lhs->length (),
554- rhs->raw_values (), op);
595+ auto expected = CompareAndFilter<TypeParam>(lhs, rhs, op);
555596 ASSERT_ARRAYS_EQUAL (*filtered_array, *expected);
556597 }
557598 }
@@ -565,8 +606,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
565606 auto rand = random::RandomArrayGenerator (kRandomSeed );
566607 for (size_t i = 3 ; i < 10 ; i++) {
567608 const int64_t length = static_cast <int64_t >(1ULL << i);
609+ // Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
610+ double null_probability = static_cast <double >(i - 3 ) / i;
568611 auto array = checked_pointer_cast<ArrayType>(
569- rand.Numeric <TypeParam>(length, 0 , 100 , /* null_probability= */ 0.0 ));
612+ rand.Numeric <TypeParam>(length, 0 , 100 , null_probability));
570613 CType c_fifty = 50 , c_hundred = 100 ;
571614 auto fifty = std::make_shared<ScalarType>(c_fifty);
572615 auto hundred = std::make_shared<ScalarType>(c_hundred);
@@ -579,8 +622,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
579622 auto filtered_array = filtered.make_array ();
580623 ValidateOutput (*filtered_array);
581624 auto expected = CompareAndFilter<TypeParam>(
582- array->raw_values (), array->length (),
583- [&](CType e) { return (e > c_fifty) && (e < c_hundred); });
625+ array, [&](CType e) { return (e > c_fifty) && (e < c_hundred); });
584626 ASSERT_ARRAYS_EQUAL (*filtered_array, *expected);
585627 }
586628}
0 commit comments