@@ -100,6 +100,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) {
100100 _insert_values_to_column_timer = ADD_TIMER (Base::custom_profile (), " InsertValuesToColumnTime" );
101101 _deserialize_data_timer = ADD_TIMER (Base::custom_profile (), " DeserializeAndMergeTime" );
102102 _hash_table_compute_timer = ADD_TIMER (Base::custom_profile (), " HashTableComputeTime" );
103+ _hash_table_limit_compute_timer =
104+ ADD_TIMER (Base::custom_profile (), " HashTableLimitComputeTime" );
103105 _hash_table_emplace_timer = ADD_TIMER (Base::custom_profile (), " HashTableEmplaceTime" );
104106 _hash_table_input_counter =
105107 ADD_COUNTER (Base::custom_profile (), " HashTableInputCount" , TUnit::UNIT);
@@ -153,16 +155,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
153155 }},
154156 _agg_data->method_variant );
155157
156- if (p._is_merge || p._needs_finalize ) {
157- return Status::InvalidArgument (
158- " StreamingAggLocalState only support no merge and no finalize, "
159- " but got is_merge={}, needs_finalize={}" ,
160- p._is_merge , p._needs_finalize );
161- }
162-
163- _should_limit_output = p._limit != -1 && // has limit
164- (!p._have_conjuncts ) && // no having conjunct
165- p._needs_finalize ; // agg's finalize step
158+ limit = p._sort_limit ;
159+ do_sort_limit = p._do_sort_limit ;
160+ null_directions = p._null_directions ;
161+ order_directions = p._order_directions ;
166162
167163 return Status::OK ();
168164}
@@ -364,6 +360,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
364360 _places.resize (rows);
365361
366362 if (_should_not_do_pre_agg (rows)) {
363+ if (limit > 0 ) {
364+ DCHECK (do_sort_limit);
365+ if (need_do_sort_limit == -1 ) {
366+ const size_t hash_table_size = _get_hash_table_size ();
367+ need_do_sort_limit = hash_table_size >= limit ? 1 : 0 ;
368+ if (need_do_sort_limit == 1 ) {
369+ build_limit_heap (hash_table_size);
370+ }
371+ }
372+
373+ if (need_do_sort_limit == 1 ) {
374+ if (_do_limit_filter (rows, key_columns)) {
375+ bool need_filter = std::find (need_computes.begin (), need_computes.end (), 1 ) !=
376+ need_computes.end ();
377+ if (need_filter) {
378+ _add_limit_heap_top (key_columns, rows);
379+ vectorized::Block::filter_block_internal (in_block, need_computes);
380+ rows = (uint32_t )in_block->rows ();
381+ } else {
382+ return Status::OK ();
383+ }
384+ }
385+ }
386+ }
367387 bool mem_reuse = p._make_nullable_keys .empty () && out_block->mem_reuse ();
368388
369389 std::vector<vectorized::DataTypePtr> data_types;
@@ -405,12 +425,23 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
405425 }
406426 }
407427 } else {
408- _emplace_into_hash_table (_places.data (), key_columns, rows);
428+ bool need_agg = true ;
429+ if (need_do_sort_limit != 1 ) {
430+ _emplace_into_hash_table (_places.data (), key_columns, rows);
431+ } else {
432+ need_agg = _emplace_into_hash_table_limit (_places.data (), in_block, key_columns, rows);
433+ }
409434
410- for (int i = 0 ; i < _aggregate_evaluators.size (); ++i) {
411- RETURN_IF_ERROR (_aggregate_evaluators[i]->execute_batch_add (
412- in_block, p._offsets_of_aggregate_states [i], _places.data (), _agg_arena_pool,
413- _should_expand_hash_table));
435+ if (need_agg) {
436+ for (int i = 0 ; i < _aggregate_evaluators.size (); ++i) {
437+ RETURN_IF_ERROR (_aggregate_evaluators[i]->execute_batch_add (
438+ in_block, p._offsets_of_aggregate_states [i], _places.data (),
439+ _agg_arena_pool, _should_expand_hash_table));
440+ }
441+ if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size () >= limit) {
442+ need_do_sort_limit = 1 ;
443+ build_limit_heap (_get_hash_table_size ());
444+ }
414445 }
415446 }
416447
@@ -562,6 +593,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da
562593 }
563594}
564595
596+ vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table () {
597+ return std::visit (
598+ vectorized::Overload {
599+ [&](std::monostate& arg) {
600+ throw doris::Exception (ErrorCode::INTERNAL_ERROR, " uninited hash table" );
601+ return vectorized::MutableColumns ();
602+ },
603+ [&](auto && agg_method) -> vectorized::MutableColumns {
604+ vectorized::MutableColumns key_columns;
605+ for (int i = 0 ; i < _probe_expr_ctxs.size (); ++i) {
606+ key_columns.emplace_back (
607+ _probe_expr_ctxs[i]->root ()->data_type ()->create_column ());
608+ }
609+ auto & data = *agg_method.hash_table ;
610+ bool has_null_key = data.has_null_key_data ();
611+ const auto size = data.size () - has_null_key;
612+ using KeyType = std::decay_t <decltype (agg_method)>::Key;
613+ std::vector<KeyType> keys (size);
614+
615+ uint32_t num_rows = 0 ;
616+ auto iter = _aggregate_data_container->begin ();
617+ {
618+ while (iter != _aggregate_data_container->end ()) {
619+ keys[num_rows] = iter.get_key <KeyType>();
620+ ++iter;
621+ ++num_rows;
622+ }
623+ }
624+ agg_method.insert_keys_into_columns (keys, key_columns, num_rows);
625+ if (has_null_key) {
626+ key_columns[0 ]->insert_data (nullptr , 0 );
627+ }
628+ return key_columns;
629+ }},
630+ _agg_data->method_variant );
631+ }
632+
633+ void StreamingAggLocalState::build_limit_heap (size_t hash_table_size) {
634+ limit_columns = _get_keys_hash_table ();
635+ for (size_t i = 0 ; i < hash_table_size; ++i) {
636+ limit_heap.emplace (i, limit_columns, order_directions, null_directions);
637+ }
638+ while (hash_table_size > limit) {
639+ limit_heap.pop ();
640+ hash_table_size--;
641+ }
642+ limit_columns_min = limit_heap.top ()._row_id ;
643+ }
644+
645+ void StreamingAggLocalState::_add_limit_heap_top (vectorized::ColumnRawPtrs& key_columns,
646+ size_t rows) {
647+ for (int i = 0 ; i < rows; ++i) {
648+ if (cmp_res[i] == 1 && need_computes[i]) {
649+ for (int j = 0 ; j < key_columns.size (); ++j) {
650+ limit_columns[j]->insert_from (*key_columns[j], i);
651+ }
652+ limit_heap.emplace (limit_columns[0 ]->size () - 1 , limit_columns, order_directions,
653+ null_directions);
654+ limit_heap.pop ();
655+ limit_columns_min = limit_heap.top ()._row_id ;
656+ break ;
657+ }
658+ }
659+ }
660+
661+ void StreamingAggLocalState::_refresh_limit_heap (size_t i, vectorized::ColumnRawPtrs& key_columns) {
662+ for (int j = 0 ; j < key_columns.size (); ++j) {
663+ limit_columns[j]->insert_from (*key_columns[j], i);
664+ }
665+ limit_heap.emplace (limit_columns[0 ]->size () - 1 , limit_columns, order_directions,
666+ null_directions);
667+ limit_heap.pop ();
668+ limit_columns_min = limit_heap.top ()._row_id ;
669+ }
670+
// Emplaces the keys of a batch into the hash table while applying the sort-limit filter.
//
// First runs _do_limit_filter() on the batch. If no row is marked with 1 in
// need_computes, nothing is emplaced and false is returned. Otherwise the block is
// filtered by need_computes (only when _do_limit_filter() returned true), every row is
// emplaced through lazy_emplace(), the resulting aggregate-state pointer is written to
// places[i], and true is returned.
//
// places:      output array receiving one aggregate-state pointer per emplaced row.
// block:       input block; filtered in place when the limit filter rejects rows.
// key_columns: key columns of the batch.
// num_rows:    number of rows in the batch; replaced by block->rows() after filtering.
// Throws doris::Exception when the hash table variant is std::monostate, and rethrows
// any exception raised while creating an aggregate state.
//
// NOTE(review): key_columns is not re-derived after filter_block_internal() runs, while
// num_rows is — confirm the raw column pointers still match the filtered block.
671+ bool StreamingAggLocalState::_emplace_into_hash_table_limit (vectorized::AggregateDataPtr* places,
672+ vectorized::Block* block,
673+ vectorized::ColumnRawPtrs& key_columns,
674+ uint32_t num_rows) {
675+ return std::visit (
676+ vectorized::Overload {
677+ [&](std::monostate& arg) {
678+ throw doris::Exception (ErrorCode::INTERNAL_ERROR, " uninited hash table" );
679+ return true ;
680+ },
681+ [&](auto && agg_method) -> bool {
// This timer scope covers the whole arm, including the limit filter and the emplace loop.
682+ SCOPED_TIMER (_hash_table_compute_timer);
683+ using HashMethodType = std::decay_t <decltype (agg_method)>;
684+ using AggState = typename HashMethodType::State;
685+
// need_filter is true when at least one row of the batch was marked 0 in need_computes.
686+ bool need_filter = _do_limit_filter (num_rows, key_columns);
// Continue only if at least one row is marked 1 in need_computes.
687+ if (auto need_agg =
688+ std::find (need_computes.begin (), need_computes.end (), 1 );
689+ need_agg != need_computes.end ()) {
690+ if (need_filter) {
691+ vectorized::Block::filter_block_internal (block, need_computes);
692+ num_rows = (uint32_t )block->rows ();
693+ }
694+
695+ AggState state (key_columns);
696+ agg_method.init_serialized_keys (key_columns, num_rows);
// i is the emplace-loop index; both callbacks below capture it by reference so that
// they see the row currently being emplaced.
697+ size_t i = 0 ;
698+
// Callback handed to lazy_emplace (presumably invoked when the key is new — confirm):
// persists the key, appends its data to the container, creates the aggregate state,
// constructs the hash-table entry and pushes row i into the limit heap.
699+ auto creator = [&](const auto & ctor, auto & key, auto & origin) {
700+ try {
701+ HashMethodType::try_presis_key_and_origin (key, origin,
702+ _agg_arena_pool);
703+ auto mapped = _aggregate_data_container->append_data (origin);
704+ auto st = _create_agg_status (mapped);
705+ if (!st) {
706+ throw Exception (st.code (), st.to_string ());
707+ }
708+ ctor (key, mapped);
709+ _refresh_limit_heap (i, key_columns);
710+ } catch (...) {
711+ // Exception-safety - if it can not allocate memory or create status,
712+ // the destructors will not be called.
// The entry is constructed with a null mapped value before the exception is rethrown.
713+ ctor (key, nullptr );
714+ throw ;
715+ }
716+ };
717+
// Callback handed to lazy_emplace for the null key: allocates the aggregate state from
// the arena using the size and alignment stored on the parent operator, creates the
// state and pushes row i into the limit heap.
718+ auto creator_for_null_key = [&](auto & mapped) {
719+ mapped = _agg_arena_pool.aligned_alloc (
720+ Base::_parent->template cast <StreamingAggOperatorX>()
721+ ._total_size_of_aggregate_states ,
722+ Base::_parent->template cast <StreamingAggOperatorX>()
723+ ._align_aggregate_states );
724+ auto st = _create_agg_status (mapped);
725+ if (!st) {
726+ throw Exception (st.code (), st.to_string ());
727+ }
728+ _refresh_limit_heap (i, key_columns);
729+ };
730+
731+ SCOPED_TIMER (_hash_table_emplace_timer);
732+ for (i = 0 ; i < num_rows; ++i) {
733+ places[i] = *agg_method.lazy_emplace (state, i, creator,
734+ creator_for_null_key);
735+ }
736+ COUNTER_UPDATE (_hash_table_input_counter, num_rows);
737+ return true ;
738+ }
// No row was marked 1 in need_computes: nothing was emplaced.
739+ return false ;
740+ }},
741+ _agg_data->method_variant );
742+ }
743+
744+ bool StreamingAggLocalState::_do_limit_filter (size_t num_rows,
745+ vectorized::ColumnRawPtrs& key_columns) {
746+ SCOPED_TIMER (_hash_table_limit_compute_timer);
747+ if (num_rows) {
748+ cmp_res.resize (num_rows);
749+ need_computes.resize (num_rows);
750+ memset (need_computes.data (), 0 , need_computes.size ());
751+ memset (cmp_res.data (), 0 , cmp_res.size ());
752+
753+ const auto key_size = null_directions.size ();
754+ for (int i = 0 ; i < key_size; i++) {
755+ key_columns[i]->compare_internal (limit_columns_min, *limit_columns[i],
756+ null_directions[i], order_directions[i], cmp_res,
757+ need_computes.data ());
758+ }
759+
760+ auto set_computes_arr = [](auto * __restrict res, auto * __restrict computes, size_t rows) {
761+ for (size_t i = 0 ; i < rows; ++i) {
762+ computes[i] = computes[i] == res[i];
763+ }
764+ };
765+ set_computes_arr (cmp_res.data (), need_computes.data (), num_rows);
766+
767+ return std::find (need_computes.begin (), need_computes.end (), 0 ) != need_computes.end ();
768+ }
769+
770+ return false ;
771+ }
772+
565773void StreamingAggLocalState::_emplace_into_hash_table (vectorized::AggregateDataPtr* places,
566774 vectorized::ColumnRawPtrs& key_columns,
567775 const uint32_t num_rows) {
@@ -617,7 +825,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
617825 _intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
618826 _output_tuple_id(tnode.agg_node.output_tuple_id),
619827 _needs_finalize(tnode.agg_node.need_finalize),
620- _is_merge(false ),
621828 _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
622829 _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
623830 _agg_fn_output_row_descriptor(descs, tnode.row_tuples),
@@ -669,8 +876,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
669876 }
670877
671878 const auto & agg_functions = tnode.agg_node .aggregate_functions ;
672- _is_merge = std::any_of (agg_functions.cbegin (), agg_functions.cend (),
673- [](const auto & e) { return e.nodes [0 ].agg_expr .is_merge_agg ; });
879+ auto is_merge = std::any_of (agg_functions.cbegin (), agg_functions.cend (),
880+ [](const auto & e) { return e.nodes [0 ].agg_expr .is_merge_agg ; });
881+ if (is_merge || _needs_finalize) {
882+ return Status::InvalidArgument (
883+ " StreamingAggLocalState only support no merge and no finalize, "
884+ " but got is_merge={}, needs_finalize={}" ,
885+ is_merge, _needs_finalize);
886+ }
887+
888+ // Handle sort limit
889+ if (tnode.agg_node .__isset .agg_sort_info_by_group_key ) {
890+ _sort_limit = _limit;
891+ _limit = -1 ;
892+ _do_sort_limit = true ;
893+ const auto & agg_sort_info = tnode.agg_node .agg_sort_info_by_group_key ;
894+ DCHECK_EQ (agg_sort_info.nulls_first .size (), agg_sort_info.is_asc_order .size ());
895+
896+ const size_t order_by_key_size = agg_sort_info.is_asc_order .size ();
897+ _order_directions.resize (order_by_key_size);
898+ _null_directions.resize (order_by_key_size);
899+ for (int i = 0 ; i < order_by_key_size; ++i) {
900+ _order_directions[i] = agg_sort_info.is_asc_order [i] ? 1 : -1 ;
901+ _null_directions[i] =
902+ agg_sort_info.nulls_first [i] ? -_order_directions[i] : _order_directions[i];
903+ }
904+ }
905+
674906 _op_name = " STREAMING_AGGREGATION_OPERATOR" ;
675907 return Status::OK ();
676908}
0 commit comments