diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 50af9c6d599..e0069a1f2c4 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -63,21 +63,113 @@ void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_sel } } -struct RankingsEmitter { - virtual ~RankingsEmitter() = default; - virtual bool NeedsDuplicates() = 0; - virtual Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) = 0; +const RankOptions* GetDefaultRankOptions() { + static const auto kDefaultRankOptions = RankOptions::Defaults(); + return &kDefaultRankOptions; +} + +const RankPercentileOptions* GetDefaultPercentileRankOptions() { + static const auto kDefaultPercentileRankOptions = RankPercentileOptions::Defaults(); + return &kDefaultPercentileRankOptions; +} + +template +Result DoSortAndMarkDuplicate( + ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Array& input, + const std::shared_ptr& physical_type, const SortOrder order, + const NullPlacement null_placement, bool needs_duplicates) { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + + ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*physical_type)); + + ArrayType array(input.data()); + ARROW_ASSIGN_OR_RAISE(auto sorted, + array_sorter(indices_begin, indices_end, array, 0, + ArraySortOptions(order, null_placement), ctx)); + + if (needs_duplicates) { + auto value_selector = [&array](int64_t index) { + return GetView::LogicalValue(array.GetView(index)); + }; + MarkDuplicates(sorted, value_selector); + } + return sorted; +} + +template +Result DoSortAndMarkDuplicate( + ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + const ChunkedArray& input, const std::shared_ptr& physical_type, + const SortOrder order, const NullPlacement null_placement, bool needs_duplicates) { + auto physical_chunks = GetPhysicalChunks(input, physical_type); + if (physical_chunks.empty()) { + return NullPartitionResult{}; + } + ARROW_ASSIGN_OR_RAISE(auto sorted, + SortChunkedArray(ctx, indices_begin, indices_end, physical_type, + physical_chunks, order, null_placement)); + if (needs_duplicates) { + const auto arrays = GetArrayPointers(physical_chunks); + auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) { + return resolver.Resolve(index).Value(); + }; + MarkDuplicates(sorted, value_selector); + } + return sorted; +} + +template +class SortAndMarkDuplicate : public TypeVisitor { + public: + SortAndMarkDuplicate(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + const InputType& input, const SortOrder order, + const NullPlacement null_placement, const bool needs_duplicate) + : TypeVisitor(), + ctx_(ctx), + indices_begin_(indices_begin), + indices_end_(indices_end), + input_(input), + order_(order), + null_placement_(null_placement), + needs_duplicates_(needs_duplicate), + physical_type_(GetPhysicalType(input.type())) {} + + Result Run() { + RETURN_NOT_OK(physical_type_->Accept(this)); + return sorted_; + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + ARROW_ASSIGN_OR_RAISE( \ + sorted_, DoSortAndMarkDuplicate(ctx_, indices_begin_, indices_end_, \ + input_, physical_type_, order_, \ + null_placement_, needs_duplicates_)); \ + return Status::OK(); \ + } + + VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + private: + ExecContext* ctx_; + uint64_t* indices_begin_; + uint64_t* indices_end_; + const InputType& input_; + const SortOrder order_; + const NullPlacement null_placement_; + const bool needs_duplicates_; + const std::shared_ptr physical_type_; + NullPartitionResult sorted_{}; }; // A helper class that emits rankings for the "rank_percentile" function -struct PercentileRankingsEmitter : public RankingsEmitter { - explicit PercentileRankingsEmitter(double factor) : factor_(factor) {} - - bool NeedsDuplicates() override { return true; } +struct PercentileRanker { + explicit PercentileRanker(double factor) : factor_(factor) {} - Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) override { + Result CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) { const int64_t length = sorted.overall_end() - sorted.overall_begin(); ARROW_ASSIGN_OR_RAISE(auto rankings, MakeMutableFloat64Array(length, ctx->memory_pool())); @@ -114,14 +206,10 @@ struct PercentileRankingsEmitter : public RankingsEmitter { }; // A helper class that emits rankings for the "rank" function -struct OrdinalRankingsEmitter : public RankingsEmitter { - explicit OrdinalRankingsEmitter(RankOptions::Tiebreaker tiebreaker) - : tiebreaker_(tiebreaker) {} +struct OrdinalRanker { + explicit OrdinalRanker(RankOptions::Tiebreaker tiebreaker) : tiebreaker_(tiebreaker) {} - bool NeedsDuplicates() override { return tiebreaker_ != RankOptions::First; } - - Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) override { + Result CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) { const int64_t length = sorted.overall_end() - sorted.overall_begin(); ARROW_ASSIGN_OR_RAISE(auto rankings, MakeMutableUInt64Array(length, ctx->memory_pool())); @@ -186,119 +274,6 @@ struct OrdinalRankingsEmitter : public RankingsEmitter { const RankOptions::Tiebreaker tiebreaker_; }; -const RankOptions* GetDefaultRankOptions() { - static const auto kDefaultRankOptions = RankOptions::Defaults(); - return &kDefaultRankOptions; -} - -const RankPercentileOptions* GetDefaultPercentileRankOptions() { - static const auto kDefaultPercentileRankOptions = RankPercentileOptions::Defaults(); - return &kDefaultPercentileRankOptions; -} - -template -class RankerMixin : public TypeVisitor { - public: - RankerMixin(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, - const InputType& input, const SortOrder order, - const NullPlacement null_placement, RankingsEmitter* emitter) - : TypeVisitor(), - ctx_(ctx), - indices_begin_(indices_begin), - indices_end_(indices_end), - input_(input), - order_(order), - null_placement_(null_placement), - physical_type_(GetPhysicalType(input.type())), - emitter_(emitter) {} - - Result Run() { - RETURN_NOT_OK(physical_type_->Accept(this)); - return emitter_->CreateRankings(ctx_, sorted_); - } - -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { \ - return static_cast(this)->template SortAndMarkDuplicates(); \ - } - - VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) - -#undef VISIT - - protected: - ExecContext* ctx_; - uint64_t* indices_begin_; - uint64_t* indices_end_; - const InputType& input_; - const SortOrder order_; - const NullPlacement null_placement_; - const std::shared_ptr physical_type_; - RankingsEmitter* emitter_; - NullPartitionResult sorted_{}; -}; - -template -class Ranker; - -template <> -class Ranker : public RankerMixin> { - public: - using RankerMixin::RankerMixin; - - template - Status SortAndMarkDuplicates() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - - ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*physical_type_)); - - ArrayType array(input_.data()); - ARROW_ASSIGN_OR_RAISE(sorted_, - array_sorter(indices_begin_, indices_end_, array, 0, - ArraySortOptions(order_, null_placement_), ctx_)); - - if (emitter_->NeedsDuplicates()) { - auto value_selector = [&array](int64_t index) { - return GetView::LogicalValue(array.GetView(index)); - }; - MarkDuplicates(sorted_, value_selector); - } - return Status::OK(); - } -}; - -template <> -class Ranker : public RankerMixin> { - public: - template - explicit Ranker(Args&&... args) - : RankerMixin(std::forward(args)...), - physical_chunks_(GetPhysicalChunks(input_, physical_type_)) {} - - template - Status SortAndMarkDuplicates() { - if (physical_chunks_.empty()) { - return Status::OK(); - } - ARROW_ASSIGN_OR_RAISE( - sorted_, SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_, - physical_chunks_, order_, null_placement_)); - if (emitter_->NeedsDuplicates()) { - const auto arrays = GetArrayPointers(physical_chunks_); - auto value_selector = [resolver = - ChunkedArrayResolver(span(arrays))](int64_t index) { - return resolver.Resolve(index).Value(); - }; - MarkDuplicates(sorted_, value_selector); - } - return Status::OK(); - } - - private: - const ArrayVector physical_chunks_; -}; - const FunctionDoc rank_doc( "Compute ordinal ranks of an array (1-based)", ("This function computes a rank of the input array.\n" @@ -324,6 +299,7 @@ const FunctionDoc rank_percentile_doc( "in RankPercentileOptions."), {"input"}, "RankPercentileOptions"); +template class RankMetaFunctionBase : public MetaFunction { public: using MetaFunction::MetaFunction; @@ -348,18 +324,16 @@ class RankMetaFunctionBase : public MetaFunction { } protected: - struct UnpackedOptions { - SortOrder order{SortOrder::Ascending}; - NullPlacement null_placement; - std::unique_ptr emitter; - }; - - virtual UnpackedOptions UnpackOptions(const FunctionOptions&) const = 0; - template Result Rank(const T& input, const FunctionOptions& function_options, ExecContext* ctx) const { - auto options = UnpackOptions(function_options); + const auto& options = + checked_cast(function_options); + + SortOrder order = SortOrder::Ascending; + if (!options.sort_keys.empty()) { + order = options.sort_keys[0].order; + } int64_t length = input.length(); ARROW_ASSIGN_OR_RAISE(auto indices, @@ -367,47 +341,49 @@ class RankMetaFunctionBase : public MetaFunction { auto* indices_begin = indices->GetMutableValues(1); auto* indices_end = indices_begin + length; std::iota(indices_begin, indices_end, 0); + auto needs_duplicates = Derived::NeedsDuplicates(options); + ARROW_ASSIGN_OR_RAISE( + auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order, + options.null_placement, needs_duplicates) + .Run()); - Ranker ranker(ctx, indices_begin, indices_end, input, options.order, - options.null_placement, options.emitter.get()); - return ranker.Run(); + auto ranker = Derived::GetRanker(options); + return ranker.CreateRankings(ctx, sorted); } }; -class RankMetaFunction : public RankMetaFunctionBase { +class RankMetaFunction : public RankMetaFunctionBase { public: - RankMetaFunction() - : RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {} + using FunctionOptionsType = RankOptions; + using RankerType = OrdinalRanker; - protected: - UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { - const auto& options = checked_cast(function_options); - UnpackedOptions unpacked{ - SortOrder::Ascending, options.null_placement, - std::make_unique(options.tiebreaker)}; - if (!options.sort_keys.empty()) { - unpacked.order = options.sort_keys[0].order; - } - return unpacked; + static bool NeedsDuplicates(const RankOptions& options) { + return options.tiebreaker != RankOptions::First; + } + + static RankerType GetRanker(const RankOptions& options) { + return RankerType(options.tiebreaker); } + + RankMetaFunction() + : RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {} }; -class RankPercentileMetaFunction : public RankMetaFunctionBase { +class RankPercentileMetaFunction + : public RankMetaFunctionBase { public: + using FunctionOptionsType = RankPercentileOptions; + using RankerType = PercentileRanker; + + static bool NeedsDuplicates(const RankPercentileOptions&) { return true; } + + static RankerType GetRanker(const RankPercentileOptions& options) { + return RankerType(options.factor); + } + RankPercentileMetaFunction() : RankMetaFunctionBase("rank_percentile", Arity::Unary(), rank_percentile_doc, GetDefaultPercentileRankOptions()) {} - - protected: - UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { - const auto& options = checked_cast(function_options); - UnpackedOptions unpacked{SortOrder::Ascending, options.null_placement, - std::make_unique(options.factor)}; - if (!options.sort_keys.empty()) { - unpacked.order = options.sort_keys[0].order; - } - return unpacked; - } }; } // namespace