@@ -138,13 +138,45 @@ inline void VisitRawValuesInline(const ArraySpan& values,
138138}
139139
140140template <typename ArrowType>
141- class ArrayCompareSorter {
141+ class ArraySorterMixin {
142+ protected:
142143 using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
143144 using GetView = GetViewType<ArrowType>;
145+ using ValueType =
146+ decltype (GetView::LogicalValue(std::declval<ArrayType>().GetView(uint64_t (0 ))));
147+
148+ void MarkDuplicates (NullPartitionResult p, const ArrayType& values,
149+ int64_t offset) const {
150+ // TODO GH-45193: what about NaNs vs. actual nulls?
151+ if (p.non_nulls_end != p.non_nulls_begin ) {
152+ auto it = p.non_nulls_begin ;
153+ ValueType prev_value = GetView::LogicalValue (values.GetView (*it - offset));
154+ while (++it < p.non_nulls_end ) {
155+ ValueType curr_value = GetView::LogicalValue (values.GetView (*it - offset));
156+ if (curr_value == prev_value) {
157+ *it |= kDuplicateMask ;
158+ }
159+ prev_value = curr_value;
160+ }
161+ }
162+ if (p.nulls_end != p.nulls_begin ) {
163+ auto it = p.nulls_begin ;
164+ while (++it < p.nulls_end ) {
165+ *it |= kDuplicateMask ;
166+ }
167+ }
168+ }
169+ };
170+
171+ template <typename ArrowType, typename Base = ArraySorterMixin<ArrowType>>
172+ class ArrayCompareSorter : Base {
173+ using ArrayType = typename Base::ArrayType;
174+ using GetView = typename Base::GetView;
144175
145176 public:
146177 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
147178 const Array& array, int64_t offset,
179+ bool mark_duplicates,
148180 const ArraySortOptions& options, ExecContext*) {
149181 const auto & values = checked_cast<const ArrayType&>(array);
150182
@@ -169,6 +201,9 @@ class ArrayCompareSorter {
169201 return rhs < lhs;
170202 });
171203 }
204+ if (mark_duplicates) {
205+ Base::MarkDuplicates (p, values, offset);
206+ }
172207 return p;
173208 }
174209};
@@ -178,6 +213,7 @@ class ArrayCompareSorter<DictionaryType> {
178213 public:
179214 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
180215 const Array& array, int64_t offset,
216+ bool mark_duplicates,
181217 const ArraySortOptions& options,
182218 ExecContext* ctx) {
183219 const auto & dict_array = checked_cast<const DictionaryArray&>(array);
@@ -220,7 +256,8 @@ class ArrayCompareSorter<DictionaryType> {
220256 DCHECK_EQ (decoded_ranks->length (), dict_array.length ());
221257 ARROW_ASSIGN_OR_RAISE (auto rank_sorter, GetArraySorter (*decoded_ranks->type ()));
222258
223- return rank_sorter (indices_begin, indices_end, *decoded_ranks, offset, options, ctx);
259+ return rank_sorter (indices_begin, indices_end, *decoded_ranks, offset,
260+ mark_duplicates, options, ctx);
224261 }
225262
226263 private:
@@ -267,17 +304,22 @@ class ArrayCompareSorter<StructType> {
267304 public:
268305 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
269306 const Array& array, int64_t offset,
307+ bool mark_duplicates,
270308 const ArraySortOptions& options,
271309 ExecContext* ctx) {
272310 const auto & struct_array = checked_cast<const StructArray&>(array);
311+ if (mark_duplicates) {
312+ // TODO (but rank currently doesn't support struct arrays)
313+ return Status::NotImplemented (" Marking duplicates not supported for StructArray" );
314+ }
273315 return SortStructArray (ctx, indices_begin, indices_end, struct_array, options.order ,
274316 options.null_placement );
275317 }
276318};
277319
278- template <typename ArrowType>
279- class ArrayCountSorter {
280- using ArrayType = typename TypeTraits<ArrowType> ::ArrayType;
320+ template <typename ArrowType, typename Base = ArraySorterMixin<ArrowType> >
321+ class ArrayCountSorter : Base {
322+ using ArrayType = typename Base ::ArrayType;
281323 using c_type = typename ArrowType::c_type;
282324
283325 public:
@@ -293,16 +335,24 @@ class ArrayCountSorter {
293335
294336 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
295337 const Array& array, int64_t offset,
338+ bool mark_duplicates,
296339 const ArraySortOptions& options,
297340 ExecContext*) const {
298341 const auto & values = checked_cast<const ArrayType&>(array);
342+ NullPartitionResult p;
299343
300344 // 32bit counter performs much better than 64bit one
301345 if (values.length () < (1LL << 32 )) {
302- return SortInternal<uint32_t >(indices_begin, indices_end, values, offset, options);
346+ p = SortInternal<uint32_t >(indices_begin, indices_end, values, offset, options);
303347 } else {
304- return SortInternal<uint64_t >(indices_begin, indices_end, values, offset, options);
348+ p = SortInternal<uint64_t >(indices_begin, indices_end, values, offset, options);
349+ }
350+ if (mark_duplicates) {
351+ // Perhaps we can mark duplicates slightly faster by doing it directly
352+ // in EmitIndices()? It probably doesn't matter for real-world tasks.
353+ Base::MarkDuplicates (p, values, offset);
305354 }
355+ return p;
306356 }
307357
308358 private:
@@ -378,6 +428,7 @@ class ArrayCountSorter<BooleanType> {
378428
379429 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
380430 const Array& array, int64_t offset,
431+ bool mark_duplicates,
381432 const ArraySortOptions& options, ExecContext*) {
382433 const auto & values = checked_cast<const BooleanArray&>(array);
383434
@@ -405,9 +456,31 @@ class ArrayCountSorter<BooleanType> {
405456 }
406457
407458 int64_t index = offset;
408- VisitRawValuesInline (
409- *values.data (), [&](bool v) { p.non_nulls_begin [counts[v]++] = index++; },
410- [&]() { p.nulls_begin [counts[2 ]++] = index++; });
459+ if (mark_duplicates) {
460+ std::array<bool , 3 > seen{}; // false, true, null (like `counts`)
461+ VisitRawValuesInline (
462+ *values.data (),
463+ [&](bool v) {
464+ if (seen[v]) {
465+ p.non_nulls_begin [counts[v]++] = index++ | kDuplicateMask ;
466+ } else {
467+ p.non_nulls_begin [counts[v]++] = index++;
468+ seen[v] = true ;
469+ }
470+ },
471+ [&]() {
472+ if (seen[2 ]) {
473+ p.nulls_begin [counts[2 ]++] = index++ | kDuplicateMask ;
474+ } else {
475+ p.nulls_begin [counts[2 ]++] = index++;
476+ seen[2 ] = true ;
477+ }
478+ });
479+ } else {
480+ VisitRawValuesInline (
481+ *values.data (), [&](bool v) { p.non_nulls_begin [counts[v]++] = index++; },
482+ [&]() { p.nulls_begin [counts[2 ]++] = index++; });
483+ }
411484 return p;
412485 }
413486};
@@ -423,6 +496,7 @@ class ArrayCountOrCompareSorter {
423496 public:
424497 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
425498 const Array& array, int64_t offset,
499+ bool mark_duplicates,
426500 const ArraySortOptions& options,
427501 ExecContext* ctx) {
428502 const auto & values = checked_cast<const ArrayType&>(array);
@@ -436,11 +510,13 @@ class ArrayCountOrCompareSorter {
436510 if (static_cast <uint64_t >(max) - static_cast <uint64_t >(min) <=
437511 countsort_max_range_) {
438512 count_sorter_.SetMinMax (min, max);
439- return count_sorter_ (indices_begin, indices_end, values, offset, options, ctx);
513+ return count_sorter_ (indices_begin, indices_end, values, offset, mark_duplicates,
514+ options, ctx);
440515 }
441516 }
442517
443- return compare_sorter_ (indices_begin, indices_end, values, offset, options, ctx);
518+ return compare_sorter_ (indices_begin, indices_end, values, offset, mark_duplicates,
519+ options, ctx);
444520 }
445521
446522 private:
@@ -464,9 +540,17 @@ class ArrayNullSorter {
464540 public:
465541 Result<NullPartitionResult> operator ()(uint64_t * indices_begin, uint64_t * indices_end,
466542 const Array& values, int64_t offset,
543+ bool mark_duplicates,
467544 const ArraySortOptions& options, ExecContext*) {
468- return NullPartitionResult::NullsOnly (indices_begin, indices_end,
469- options.null_placement );
545+ auto p = NullPartitionResult::NullsOnly (indices_begin, indices_end,
546+ options.null_placement );
547+ if (mark_duplicates && p.nulls_end != p.nulls_begin ) {
548+ auto it = p.nulls_begin ;
549+ while (++it != p.nulls_end ) {
550+ *it |= kDuplicateMask ;
551+ }
552+ }
553+ return p;
470554 }
471555};
472556
@@ -550,7 +634,9 @@ struct ArraySortIndices {
550634 ArrayType arr (batch[0 ].array .ToArrayData ());
551635 ARROW_ASSIGN_OR_RAISE (auto sorter, GetArraySorter (*GetPhysicalType (arr.type ())));
552636
553- return sorter (out_begin, out_end, arr, 0 , options, ctx->exec_context ()).status ();
637+ return sorter (out_begin, out_end, arr, /* offset=*/ 0 , /* mark_duplicates=*/ false ,
638+ options, ctx->exec_context ())
639+ .status ();
554640 }
555641};
556642
@@ -561,9 +647,9 @@ Status ArraySortIndicesChunked(KernelContext* ctx, const ExecBatch& batch, Datum
561647 uint64_t * out_begin = out_arr->GetMutableValues <uint64_t >(1 );
562648 uint64_t * out_end = out_begin + out_arr->length ;
563649 std::iota (out_begin, out_end, 0 );
564- return SortChunkedArray (ctx-> exec_context (), out_begin, out_end,
565- *batch[0 ].chunked_array (), options. order ,
566- options.null_placement )
650+ return SortChunkedArray (
651+ ctx-> exec_context (), out_begin, out_end, *batch[0 ].chunked_array (),
652+ /* mark_duplicates= */ false , options. order , options.null_placement )
567653 .status ();
568654}
569655
0 commit comments