diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp
index 3a8638402a8a89..77673a4a3d37a8 100644
--- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp
+++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp
@@ -100,6 +100,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) {
     _insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime");
     _deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime");
     _hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime");
+    _hash_table_limit_compute_timer =
+            ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime");
     _hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime");
     _hash_table_input_counter =
             ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT);
@@ -153,16 +155,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
             }},
             _agg_data->method_variant);
 
-    if (p._is_merge || p._needs_finalize) {
-        return Status::InvalidArgument(
-                "StreamingAggLocalState only support no merge and no finalize, "
-                "but got is_merge={}, needs_finalize={}",
-                p._is_merge, p._needs_finalize);
-    }
-
-    _should_limit_output = p._limit != -1 &&       // has limit
-                           (!p._have_conjuncts) && // no having conjunct
-                           p._needs_finalize;      // agg's finalize step
+    limit = p._sort_limit;
+    do_sort_limit = p._do_sort_limit;
+    null_directions = p._null_directions;
+    order_directions = p._order_directions;
 
     return Status::OK();
 }
@@ -364,6 +360,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
     _places.resize(rows);
 
     if (_should_not_do_pre_agg(rows)) {
+        if (limit > 0) {
+            DCHECK(do_sort_limit);
+            if (need_do_sort_limit == -1) {
+                const size_t hash_table_size = _get_hash_table_size();
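+                // need_do_sort_limit is a tri-state flag: -1 means undecided, 0 means the
+                // hash table is still smaller than the limit (keep pre-aggregating as
+                // usual), and 1 means the limit was reached and the top-n heap was built.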
+                need_do_sort_limit = hash_table_size >= limit ? 1 : 0;
+                if (need_do_sort_limit == 1) {
+                    build_limit_heap(hash_table_size);
+                }
+            }
+
+            if (need_do_sort_limit == 1) {
+                if (_do_limit_filter(rows, key_columns)) {
+                    bool need_filter = std::find(need_computes.begin(), need_computes.end(),
+                                                 1) != need_computes.end();
+                    if (need_filter) {
+                        _add_limit_heap_top(key_columns, rows);
+                        vectorized::Block::filter_block_internal(in_block, need_computes);
+                        rows = (uint32_t)in_block->rows();
+                    } else {
+                        return Status::OK();
+                    }
+                }
+            }
+        }
         bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse();
 
         std::vector<vectorized::DataTypePtr> data_types;
@@ -405,12 +425,23 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
             }
         }
     } else {
-        _emplace_into_hash_table(_places.data(), key_columns, rows);
+        bool need_agg = true;
+        if (need_do_sort_limit != 1) {
+            _emplace_into_hash_table(_places.data(), key_columns, rows);
+        } else {
+            need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns,
+                                                      rows);
+        }
 
-        for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-            RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
-                    in_block, p._offsets_of_aggregate_states[i], _places.data(), _agg_arena_pool,
-                    _should_expand_hash_table));
+        if (need_agg) {
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                        in_block, p._offsets_of_aggregate_states[i], _places.data(),
+                        _agg_arena_pool, _should_expand_hash_table));
+            }
+            if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) {
+                need_do_sort_limit = 1;
+                build_limit_heap(_get_hash_table_size());
+            }
         }
     }
 
@@ -562,6 +593,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da
     }
 }
 
+vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table() {
+    return std::visit(
+            vectorized::Overload{
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return vectorized::MutableColumns();
+                    },
+                    [&](auto&& agg_method) -> vectorized::MutableColumns {
+                        vectorized::MutableColumns key_columns;
+                        for (int i = 0; i < _probe_expr_ctxs.size(); ++i) {
+                            key_columns.emplace_back(
+                                    _probe_expr_ctxs[i]->root()->data_type()->create_column());
+                        }
+                        auto& data = *agg_method.hash_table;
+                        bool has_null_key = data.has_null_key_data();
+                        const auto size = data.size() - has_null_key;
+                        using KeyType = std::decay_t<decltype(data)>::Key;
+                        std::vector<KeyType> keys(size);
+
+                        uint32_t num_rows = 0;
+                        auto iter = _aggregate_data_container->begin();
+                        {
+                            while (iter != _aggregate_data_container->end()) {
+                                keys[num_rows] = iter.template get_key<KeyType>();
+                                ++iter;
+                                ++num_rows;
+                            }
+                        }
+                        agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                        if (has_null_key) {
+                            key_columns[0]->insert_data(nullptr, 0);
+                        }
+                        return key_columns;
+                    }},
+            _agg_data->method_variant);
+}
+
+void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) {
+    limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        limit_heap.emplace(i, limit_columns, order_directions, null_directions);
+    }
+    while (hash_table_size > limit) {
+        limit_heap.pop();
+        hash_table_size--;
+    }
+    limit_columns_min = limit_heap.top()._row_id;
+}
+
+void StreamingAggLocalState::_add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns,
+                                                 size_t rows) {
+    for (int i = 0; i < rows; ++i) {
+        if (cmp_res[i] == 1 && need_computes[i]) {
+            for (int j = 0; j < key_columns.size(); ++j) {
+                limit_columns[j]->insert_from(*key_columns[j], i);
+            }
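+            // Push the surviving key as a new heap entry and pop the old top, so the heap
+            // keeps exactly `limit` entries and the new top becomes the comparison
+            // threshold (limit_columns_min) for later blocks.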
+            limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
+                               null_directions);
+            limit_heap.pop();
+            limit_columns_min = limit_heap.top()._row_id;
+            break;
+        }
+    }
+}
+
+void StreamingAggLocalState::_refresh_limit_heap(size_t i,
+                                                 vectorized::ColumnRawPtrs& key_columns) {
+    for (int j = 0; j < key_columns.size(); ++j) {
+        limit_columns[j]->insert_from(*key_columns[j], i);
+    }
+    limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
+                       null_directions);
+    limit_heap.pop();
+    limit_columns_min = limit_heap.top()._row_id;
+}
+
+bool StreamingAggLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                                            vectorized::Block* block,
+                                                            vectorized::ColumnRawPtrs& key_columns,
+                                                            uint32_t num_rows) {
+    return std::visit(
+            vectorized::Overload{
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return true;
+                    },
+                    [&](auto&& agg_method) -> bool {
+                        SCOPED_TIMER(_hash_table_compute_timer);
+                        using HashMethodType = std::decay_t<decltype(agg_method)>;
+                        using AggState = typename HashMethodType::State;
+
+                        bool need_filter = _do_limit_filter(num_rows, key_columns);
+                        if (auto need_agg =
+                                    std::find(need_computes.begin(), need_computes.end(), 1);
+                            need_agg != need_computes.end()) {
+                            if (need_filter) {
+                                vectorized::Block::filter_block_internal(block, need_computes);
+                                num_rows = (uint32_t)block->rows();
+                            }
+
+                            AggState state(key_columns);
+                            agg_method.init_serialized_keys(key_columns, num_rows);
+                            size_t i = 0;
+
+                            auto creator = [&](const auto& ctor, auto& key, auto& origin) {
+                                try {
+                                    HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                              _agg_arena_pool);
+                                    auto mapped =
+                                            _aggregate_data_container->append_data(origin);
+                                    auto st = _create_agg_status(mapped);
+                                    if (!st) {
+                                        throw Exception(st.code(), st.to_string());
+                                    }
+                                    ctor(key, mapped);
+                                    _refresh_limit_heap(i, key_columns);
+                                } catch (...) {
+                                    // Exception-safety - if it can not allocate memory or create
+                                    // status, the destructors will not be called.
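+                                    // Construct the entry with a null mapped value so the
+                                    // hash table stays consistent before rethrowing.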
+                                    ctor(key, nullptr);
+                                    throw;
+                                }
+                            };
+
+                            auto creator_for_null_key = [&](auto& mapped) {
+                                mapped = _agg_arena_pool.aligned_alloc(
+                                        Base::_parent->template cast<StreamingAggOperatorX>()
+                                                ._total_size_of_aggregate_states,
+                                        Base::_parent->template cast<StreamingAggOperatorX>()
+                                                ._align_aggregate_states);
+                                auto st = _create_agg_status(mapped);
+                                if (!st) {
+                                    throw Exception(st.code(), st.to_string());
+                                }
+                                _refresh_limit_heap(i, key_columns);
+                            };
+
+                            SCOPED_TIMER(_hash_table_emplace_timer);
+                            for (i = 0; i < num_rows; ++i) {
+                                places[i] = *agg_method.lazy_emplace(state, i, creator,
+                                                                     creator_for_null_key);
+                            }
+                            COUNTER_UPDATE(_hash_table_input_counter, num_rows);
+                            return true;
+                        }
+                        return false;
+                    }},
+            _agg_data->method_variant);
+}
+
+bool StreamingAggLocalState::_do_limit_filter(size_t num_rows,
+                                              vectorized::ColumnRawPtrs& key_columns) {
+    SCOPED_TIMER(_hash_table_limit_compute_timer);
+    if (num_rows) {
+        cmp_res.resize(num_rows);
+        need_computes.resize(num_rows);
+        memset(need_computes.data(), 0, need_computes.size());
+        memset(cmp_res.data(), 0, cmp_res.size());
+
+        const auto key_size = null_directions.size();
+        for (int i = 0; i < key_size; i++) {
+            key_columns[i]->compare_internal(limit_columns_min, *limit_columns[i],
+                                             null_directions[i], order_directions[i], cmp_res,
+                                             need_computes.data());
+        }
+
+        auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) {
+            for (size_t i = 0; i < rows; ++i) {
+                computes[i] = computes[i] == res[i];
+            }
+        };
+        set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
+
+        return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
+    }
+
+    return false;
+}
+
 void StreamingAggLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* places,
                                                       vectorized::ColumnRawPtrs& key_columns,
                                                       const uint32_t num_rows) {
@@ -617,7 +825,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
           _intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
           _output_tuple_id(tnode.agg_node.output_tuple_id),
           _needs_finalize(tnode.agg_node.need_finalize),
-          _is_merge(false),
           _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
           _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
           _agg_fn_output_row_descriptor(descs, tnode.row_tuples),
@@ -669,8 +876,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
     }
 
     const auto& agg_functions = tnode.agg_node.aggregate_functions;
-    _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
-                            [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
+    auto is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
+                                [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
+    if (is_merge || _needs_finalize) {
+        return Status::InvalidArgument(
+                "StreamingAggLocalState only support no merge and no finalize, "
+                "but got is_merge={}, needs_finalize={}",
+                is_merge, _needs_finalize);
+    }
+
+    // Handle sort limit
+    if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+        _sort_limit = _limit;
+        _limit = -1;
+        _do_sort_limit = true;
+        const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+        DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());
+
+        const size_t order_by_key_size = agg_sort_info.is_asc_order.size();
+        _order_directions.resize(order_by_key_size);
+        _null_directions.resize(order_by_key_size);
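+        // Each order-by key direction is encoded as +1 (ascending) or -1 (descending);
+        // when NULLs sort first the null direction is flipped, so NULL compares as the
+        // extreme value on the proper end for either sort order.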
+        for (int i = 0; i < order_by_key_size; ++i) {
+            _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+            _null_directions[i] =
+                    agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
+        }
+    }
+
     _op_name = "STREAMING_AGGREGATION_OPERATOR";
     return Status::OK();
 }
diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.h b/be/src/pipeline/exec/streaming_aggregation_operator.h
index a0c5985a1b9f22..08e28ab5e9fefe 100644
--- a/be/src/pipeline/exec/streaming_aggregation_operator.h
+++ b/be/src/pipeline/exec/streaming_aggregation_operator.h
@@ -48,6 +48,7 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState<FakeSharedState> {
     RuntimeProfile::Counter* _insert_values_to_column_timer = nullptr;
     RuntimeProfile::Counter* _deserialize_data_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
@@ -104,7 +105,75 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState<FakeSharedState> {
     std::unique_ptr<vectorized::Arena> _agg_profile_arena = nullptr;
     std::unique_ptr<vectorized::AggregateDataContainer> _aggregate_data_container = nullptr;
-    bool _should_limit_output = false;
     bool _reach_limit = false;
     size_t _input_num_rows = 0;
+    int64_t limit = -1;
+    int need_do_sort_limit = -1;
+    bool do_sort_limit = false;
+    vectorized::MutableColumns limit_columns;
+    int limit_columns_min = -1;
+    vectorized::PaddedPODArray<vectorized::UInt8> need_computes;
+    std::vector<uint8_t> cmp_res;
+    std::vector<int> order_directions;
+    std::vector<int> null_directions;
+
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, vectorized::MutableColumns& limit_columns,
+                        std::vector<int>& order_directions,
+                        std::vector<int>& null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) = default;
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (int i = 0; i < _limit_columns.size(); ++i) {
+                const auto& _limit_column = _limit_columns[i];
+                auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column,
+                                                     _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        vectorized::MutableColumns& _limit_columns;
+        std::vector<int>& _order_directions;
+        std::vector<int>& _null_directions;
+    };
+
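+    // std::priority_queue with HeapLimitCursor's operator< forms a max-heap over the
+    // kept group keys: the top is the last-ordered of the current top-`limit` keys, so
+    // each incoming key is checked only against the row at limit_columns_min.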
+    std::priority_queue<HeapLimitCursor> limit_heap;
+
+    vectorized::MutableColumns _get_keys_hash_table();
+    void build_limit_heap(size_t hash_table_size);
+    void _add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns, size_t rows);
+    void _refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns);
+    bool _emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                        vectorized::Block* block,
+                                        vectorized::ColumnRawPtrs& key_columns, uint32_t num_rows);
+    bool _do_limit_filter(size_t num_rows, vectorized::ColumnRawPtrs& key_columns);
+
     vectorized::PODArray<vectorized::AggregateDataPtr> _places;
     std::vector<char> _deserialize_buffer;
@@ -180,8 +249,14 @@ class StreamingAggOperatorX MOCK_REMOVE(final) : public StatefulOperatorX<StreamingAggLocalState> {
     bool _needs_finalize;
-    bool _is_merge;
     const bool _is_first_phase;
     std::vector<size_t> _make_nullable_keys;
     bool _have_conjuncts;
     RowDescriptor _agg_fn_output_row_descriptor;
+
+    // For sort limit
+    bool _do_sort_limit = false;
+    int64_t _sort_limit = -1;
+    std::vector<int> _order_directions;
+    std::vector<int> _null_directions;
+
     const std::vector<TExpr> _partition_exprs;
 };
diff --git a/be/test/pipeline/operator/streaming_agg_operator_test.cpp b/be/test/pipeline/operator/streaming_agg_operator_test.cpp
index 91ca56572be9a8..68462d4c3c2605 100644
--- a/be/test/pipeline/operator/streaming_agg_operator_test.cpp
+++ b/be/test/pipeline/operator/streaming_agg_operator_test.cpp
@@ -109,7 +109,6 @@ TEST_F(StreamingAggOperatorTest, test1) {
                                                false));
     op->_pool = &pool;
     op->_needs_finalize = false;
-    op->_is_merge = false;
 
     EXPECT_TRUE(op->set_child(child_op));
 
@@ -166,7 +165,6 @@ TEST_F(StreamingAggOperatorTest, test2) {
                                                false));
     op->_pool = &pool;
     op->_needs_finalize = false;
-    op->_is_merge = false;
 
     EXPECT_TRUE(op->set_child(child_op));
 
@@ -243,7 +241,6 @@ TEST_F(StreamingAggOperatorTest, test3) {
                                                false));
     op->_pool = &pool;
     op->_needs_finalize = false;
-    op->_is_merge = false;
 
     EXPECT_TRUE(op->set_child(child_op));
 
@@ -323,7 +320,6 @@ TEST_F(StreamingAggOperatorTest, test4) {
                                                std::make_shared(), false));
     op->_pool = &pool;
     op->_needs_finalize = false;
-    op->_is_merge = false;
 
     EXPECT_TRUE(op->set_child(child_op));
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 84e7c400269cf3..3ccb03d20d80ae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -338,8 +338,7 @@ public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends Plan> d
         if (upstreamFragment.getPlanRoot() instanceof AggregationNode && upstream instanceof PhysicalHashAggregate) {
             PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) upstream;
             if (hashAggregate.getAggPhase() == AggPhase.LOCAL
-                    && hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER
-                    && hashAggregate.getTopnPushInfo() == null) {
+                    && hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) {
                 AggregationNode aggregationNode = (AggregationNode) upstreamFragment.getPlanRoot();
                 aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream());
             }
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
index 06975eef5eaa29..5e694b4781d1f7 100644
--- a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
+++ b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
@@ -32,7 +32,6 @@ suite("push_topn_to_agg") {
     explain{
         sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey limit 4;"
         multiContains ("sortByGroupKey:true", 2)
-        notContains("STREAMING")
     }
 
     // when apply this opt, trun off STREAMING
@@ -40,14 +39,12 @@ suite("push_topn_to_agg") {
     explain{
         sql "select sum(c_custkey), c_name from customer group by c_name limit 6;"
         multiContains ("sortByGroupKey:true", 2)
-        notContains("STREAMING")
     }
 
     // topn -> agg
     explain{
         sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey limit 8;"
        multiContains ("sortByGroupKey:true", 2)
-        notContains("STREAMING")
    }
 
     // order keys are part of group keys,
@@ -185,4 +182,4 @@ suite("push_topn_to_agg") {
     | planed with unknown column statistics                                          |
     +--------------------------------------------------------------------------------+
     **/
-}
\ No newline at end of file
+}