Skip to content

Commit dfb1c98

Browse files
committed
[C++][Compute] Refactor rank function implementation
1 parent 04249b9 commit dfb1c98

File tree

7 files changed

+330
-177
lines changed

7 files changed

+330
-177
lines changed

cpp/src/arrow/compute/kernels/chunked_internal.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -112,8 +112,9 @@ Status ChunkedIndexMapper::PhysicalToLogical() {
112112
DCHECK_LT(loc.chunk_index(), chunk_offsets.size());
113113
DCHECK_LT(loc.index_in_chunk(),
114114
static_cast<uint64_t>(chunk_lengths_[loc.chunk_index()]));
115-
indices_begin_[i] =
116-
chunk_offsets[loc.chunk_index()] + static_cast<int64_t>(loc.index_in_chunk());
115+
const uint64_t logical_index =
116+
static_cast<uint64_t>(chunk_offsets[loc.chunk_index()]) + loc.index_in_chunk();
117+
indices_begin_[i] = logical_index | loc.logical_duplicate_mask();
117118
}
118119

119120
return Status::OK();

cpp/src/arrow/compute/kernels/chunked_internal.h

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,19 +62,29 @@ struct ResolvedChunk {
6262
// (see ChunkedIndexMapper)
6363
struct CompressedChunkLocation {
6464
static constexpr int kChunkIndexBits = 24;
65-
static constexpr int KIndexInChunkBits = 64 - kChunkIndexBits;
65+
static constexpr int KIndexInChunkBits = 63 - kChunkIndexBits;
6666

6767
static constexpr uint64_t kMaxChunkIndex = (1ULL << kChunkIndexBits) - 1;
6868
static constexpr uint64_t kMaxIndexInChunk = (1ULL << KIndexInChunkBits) - 1;
69+
// An optional bit between the two indices to mark duplicates for the rank function.
70+
static constexpr uint64_t kCompressedDuplicateMask = 1ULL << kChunkIndexBits;
6971

7072
CompressedChunkLocation() = default;
7173

7274
// Chunk index, read from the low kChunkIndexBits bits of data_
// (kMaxChunkIndex is the all-ones mask of that width).
constexpr uint64_t chunk_index() const { return data_ & kMaxChunkIndex; }
73-
constexpr uint64_t index_in_chunk() const { return data_ >> kChunkIndexBits; }
// New version: the index-in-chunk field now starts one bit higher, so the shift
// skips both the chunk index bits and the duplicate bit at position kChunkIndexBits.
75+
constexpr uint64_t index_in_chunk() const { return data_ >> (kChunkIndexBits + 1); }
76+
// True when the duplicate bit (bit kChunkIndexBits of data_) is set.
constexpr bool is_duplicate() const { return (data_ & kCompressedDuplicateMask) != 0; }
77+
78+
// Turn the duplicate bit into a duplicate mask for logical indices.
// The duplicate bit sits at position kChunkIndexBits; shifting it left by
// (63 - kChunkIndexBits) moves it to bit 63. The result is 0 when the
// duplicate bit is not set, so it can be OR-ed into an index unconditionally.
79+
constexpr uint64_t logical_duplicate_mask() const {
80+
return (data_ & kCompressedDuplicateMask) << (63 - kChunkIndexBits);
81+
}
82+
83+
// Set the duplicate bit in place. Only bit kChunkIndexBits is OR-ed in, so the
// chunk index and index-in-chunk fields are left unchanged.
void MarkDuplicate() { data_ |= kCompressedDuplicateMask; }
7484

7585
// Pack a (chunk_index, index_in_chunk) pair into data_: chunk_index occupies the
// low bits and index_in_chunk is shifted above it. In the new initializer the
// shift is kChunkIndexBits + 1, which leaves the duplicate bit (bit
// kChunkIndexBits) clear on construction.
// No range check on either argument is performed in this constructor.
explicit constexpr CompressedChunkLocation(uint64_t chunk_index,
7686
uint64_t index_in_chunk)
77-
: data_((index_in_chunk << kChunkIndexBits) | chunk_index) {}
87+
: data_((index_in_chunk << (kChunkIndexBits + 1)) | chunk_index) {}
7888

7989
template <typename IndexType>
8090
explicit operator TypedChunkLocation<IndexType>() {

cpp/src/arrow/compute/kernels/vector_array_sort.cc

Lines changed: 104 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -138,13 +138,45 @@ inline void VisitRawValuesInline(const ArraySpan& values,
138138
}
139139

140140
// Base class for the array sorters shown in this diff: provides the shared type
// aliases (ArrayType, GetView, ValueType) and the MarkDuplicates() helper.
// The removed line below is the old class header (ArrayCompareSorter); the added
// line introduces ArraySorterMixin in its place.
template <typename ArrowType>
141-
class ArrayCompareSorter {
141+
class ArraySorterMixin {
142+
protected:
142143
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
143144
using GetView = GetViewType<ArrowType>;
// ValueType is whatever GetView::LogicalValue() returns for an element view of
// ArrayType; it is only used to hold values for the equality test below.
145+
using ValueType =
146+
decltype(GetView::LogicalValue(std::declval<ArrayType>().GetView(uint64_t(0))));
147+
// Flags duplicate entries in an index range by OR-ing kDuplicateMask into them.
//
// p:      partition of the index range into non-null and null sub-ranges
//         (passed by value; the indices it points to are modified in place).
// values: array the indices refer to.
// offset: subtracted from each stored index before calling values.GetView().
//
// In the non-null range, an index is flagged when its value compares equal
// (operator==) to the value of the index immediately before it. In the null
// range, every index except the first is flagged. The first entry of each
// range is never flagged.
// NOTE(review): flagging only adjacent equals presumably relies on the range
// already being sorted; the callers visible in this diff invoke it after sorting.
// kDuplicateMask is defined outside this view.
148+
void MarkDuplicates(NullPartitionResult p, const ArrayType& values,
149+
int64_t offset) const {
150+
// TODO GH-45193: what about NaNs vs. actual nulls?
151+
if (p.non_nulls_end != p.non_nulls_begin) {
152+
auto it = p.non_nulls_begin;
153+
ValueType prev_value = GetView::LogicalValue(values.GetView(*it - offset));
154+
while (++it < p.non_nulls_end) {
// *it is read here before it is possibly flagged below; the previous value is
// carried in prev_value, so a flagged index is never used for a lookup again.
155+
ValueType curr_value = GetView::LogicalValue(values.GetView(*it - offset));
156+
if (curr_value == prev_value) {
157+
*it |= kDuplicateMask;
158+
}
159+
prev_value = curr_value;
160+
}
161+
}
// All nulls are treated as equal to each other: flag every one after the first.
162+
if (p.nulls_end != p.nulls_begin) {
163+
auto it = p.nulls_begin;
164+
while (++it < p.nulls_end) {
165+
*it |= kDuplicateMask;
166+
}
167+
}
168+
}
169+
};
170+
171+
template <typename ArrowType, typename Base = ArraySorterMixin<ArrowType>>
172+
class ArrayCompareSorter : Base {
173+
using ArrayType = typename Base::ArrayType;
174+
using GetView = typename Base::GetView;
144175

145176
public:
146177
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
147178
const Array& array, int64_t offset,
179+
bool mark_duplicates,
148180
const ArraySortOptions& options, ExecContext*) {
149181
const auto& values = checked_cast<const ArrayType&>(array);
150182

@@ -169,6 +201,9 @@ class ArrayCompareSorter {
169201
return rhs < lhs;
170202
});
171203
}
204+
if (mark_duplicates) {
205+
Base::MarkDuplicates(p, values, offset);
206+
}
172207
return p;
173208
}
174209
};
@@ -178,6 +213,7 @@ class ArrayCompareSorter<DictionaryType> {
178213
public:
179214
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
180215
const Array& array, int64_t offset,
216+
bool mark_duplicates,
181217
const ArraySortOptions& options,
182218
ExecContext* ctx) {
183219
const auto& dict_array = checked_cast<const DictionaryArray&>(array);
@@ -220,7 +256,8 @@ class ArrayCompareSorter<DictionaryType> {
220256
DCHECK_EQ(decoded_ranks->length(), dict_array.length());
221257
ARROW_ASSIGN_OR_RAISE(auto rank_sorter, GetArraySorter(*decoded_ranks->type()));
222258

223-
return rank_sorter(indices_begin, indices_end, *decoded_ranks, offset, options, ctx);
259+
return rank_sorter(indices_begin, indices_end, *decoded_ranks, offset,
260+
mark_duplicates, options, ctx);
224261
}
225262

226263
private:
@@ -267,17 +304,22 @@ class ArrayCompareSorter<StructType> {
267304
public:
268305
// Sorts the indices of a struct array by delegating to SortStructArray() with
// options.order and options.null_placement.
// mark_duplicates is not supported here: when it is true the call returns
// Status::NotImplemented before any sorting is done.
// The offset parameter is not used in this overload.
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
269306
const Array& array, int64_t offset,
307+
bool mark_duplicates,
270308
const ArraySortOptions& options,
271309
ExecContext* ctx) {
272310
const auto& struct_array = checked_cast<const StructArray&>(array);
311+
if (mark_duplicates) {
312+
// TODO: implement duplicate marking for struct arrays
// (not needed yet: rank currently doesn't support struct arrays)
313+
return Status::NotImplemented("Marking duplicates not supported for StructArray");
314+
}
273315
return SortStructArray(ctx, indices_begin, indices_end, struct_array, options.order,
274316
options.null_placement);
275317
}
276318
};
277319

278-
template <typename ArrowType>
279-
class ArrayCountSorter {
280-
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
320+
template <typename ArrowType, typename Base = ArraySorterMixin<ArrowType>>
321+
class ArrayCountSorter : Base {
322+
using ArrayType = typename Base::ArrayType;
281323
using c_type = typename ArrowType::c_type;
282324

283325
public:
@@ -293,16 +335,24 @@ class ArrayCountSorter {
293335

294336
// Sorts the indices via SortInternal(), choosing the counter width from the
// array length, then optionally flags duplicates.
// The removed lines returned the SortInternal() result directly; the added lines
// store it in `p` so that Base::MarkDuplicates() can run on it before returning.
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
295337
const Array& array, int64_t offset,
338+
bool mark_duplicates,
296339
const ArraySortOptions& options,
297340
ExecContext*) const {
298341
const auto& values = checked_cast<const ArrayType&>(array);
342+
NullPartitionResult p;
299343

300344
// 32bit counter performs much better than 64bit one
// (used whenever the array length is below 2^32)
301345
if (values.length() < (1LL << 32)) {
302-
return SortInternal<uint32_t>(indices_begin, indices_end, values, offset, options);
346+
p = SortInternal<uint32_t>(indices_begin, indices_end, values, offset, options);
303347
} else {
304-
return SortInternal<uint64_t>(indices_begin, indices_end, values, offset, options);
348+
p = SortInternal<uint64_t>(indices_begin, indices_end, values, offset, options);
349+
}
350+
if (mark_duplicates) {
351+
// Perhaps we can mark duplicates slightly faster by doing it directly
352+
// in EmitIndices()? It probably doesn't matter for real-world tasks.
353+
Base::MarkDuplicates(p, values, offset);
305354
}
355+
return p;
306356
}
307357

308358
private:
@@ -378,6 +428,7 @@ class ArrayCountSorter<BooleanType> {
378428

379429
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
380430
const Array& array, int64_t offset,
431+
bool mark_duplicates,
381432
const ArraySortOptions& options, ExecContext*) {
382433
const auto& values = checked_cast<const BooleanArray&>(array);
383434

@@ -405,9 +456,31 @@ class ArrayCountSorter<BooleanType> {
405456
}
406457

407458
int64_t index = offset;
408-
VisitRawValuesInline(
409-
*values.data(), [&](bool v) { p.non_nulls_begin[counts[v]++] = index++; },
410-
[&]() { p.nulls_begin[counts[2]++] = index++; });
459+
if (mark_duplicates) {
460+
std::array<bool, 3> seen{}; // false, true, null (like `counts`)
461+
VisitRawValuesInline(
462+
*values.data(),
463+
[&](bool v) {
464+
if (seen[v]) {
465+
p.non_nulls_begin[counts[v]++] = index++ | kDuplicateMask;
466+
} else {
467+
p.non_nulls_begin[counts[v]++] = index++;
468+
seen[v] = true;
469+
}
470+
},
471+
[&]() {
472+
if (seen[2]) {
473+
p.nulls_begin[counts[2]++] = index++ | kDuplicateMask;
474+
} else {
475+
p.nulls_begin[counts[2]++] = index++;
476+
seen[2] = true;
477+
}
478+
});
479+
} else {
480+
VisitRawValuesInline(
481+
*values.data(), [&](bool v) { p.non_nulls_begin[counts[v]++] = index++; },
482+
[&]() { p.nulls_begin[counts[2]++] = index++; });
483+
}
411484
return p;
412485
}
413486
};
@@ -423,6 +496,7 @@ class ArrayCountOrCompareSorter {
423496
public:
424497
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
425498
const Array& array, int64_t offset,
499+
bool mark_duplicates,
426500
const ArraySortOptions& options,
427501
ExecContext* ctx) {
428502
const auto& values = checked_cast<const ArrayType&>(array);
@@ -436,11 +510,13 @@ class ArrayCountOrCompareSorter {
436510
if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <=
437511
countsort_max_range_) {
438512
count_sorter_.SetMinMax(min, max);
439-
return count_sorter_(indices_begin, indices_end, values, offset, options, ctx);
513+
return count_sorter_(indices_begin, indices_end, values, offset, mark_duplicates,
514+
options, ctx);
440515
}
441516
}
442517

443-
return compare_sorter_(indices_begin, indices_end, values, offset, options, ctx);
518+
return compare_sorter_(indices_begin, indices_end, values, offset, mark_duplicates,
519+
options, ctx);
444520
}
445521

446522
private:
@@ -464,9 +540,17 @@ class ArrayNullSorter {
464540
public:
465541
// Builds the partition via NullPartitionResult::NullsOnly() over the whole index
// range, placed according to options.null_placement. The `values` and `offset`
// parameters are not used in this overload.
// The removed lines returned that result directly; the added lines keep it in `p`
// so duplicates can be flagged first.
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
466542
const Array& values, int64_t offset,
543+
bool mark_duplicates,
467544
const ArraySortOptions& options, ExecContext*) {
468-
return NullPartitionResult::NullsOnly(indices_begin, indices_end,
469-
options.null_placement);
545+
auto p = NullPartitionResult::NullsOnly(indices_begin, indices_end,
546+
options.null_placement);
// When requested, flag every index in the null range except the first with
// kDuplicateMask; the pre-increment in the loop condition skips the first one.
// An empty range is left untouched by the guard below.
547+
if (mark_duplicates && p.nulls_end != p.nulls_begin) {
548+
auto it = p.nulls_begin;
549+
while (++it != p.nulls_end) {
550+
*it |= kDuplicateMask;
551+
}
552+
}
553+
return p;
470554
}
471555
};
472556

@@ -550,7 +634,9 @@ struct ArraySortIndices {
550634
ArrayType arr(batch[0].array.ToArrayData());
551635
ARROW_ASSIGN_OR_RAISE(auto sorter, GetArraySorter(*GetPhysicalType(arr.type())));
552636

553-
return sorter(out_begin, out_end, arr, 0, options, ctx->exec_context()).status();
637+
return sorter(out_begin, out_end, arr, /*offset=*/0, /*mark_duplicates=*/false,
638+
options, ctx->exec_context())
639+
.status();
554640
}
555641
};
556642

@@ -561,9 +647,9 @@ Status ArraySortIndicesChunked(KernelContext* ctx, const ExecBatch& batch, Datum
561647
uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1);
562648
uint64_t* out_end = out_begin + out_arr->length;
563649
std::iota(out_begin, out_end, 0);
564-
return SortChunkedArray(ctx->exec_context(), out_begin, out_end,
565-
*batch[0].chunked_array(), options.order,
566-
options.null_placement)
650+
return SortChunkedArray(
651+
ctx->exec_context(), out_begin, out_end, *batch[0].chunked_array(),
652+
/*mark_duplicates=*/false, options.order, options.null_placement)
567653
.status();
568654
}
569655

0 commit comments

Comments
 (0)