Skip to content

Commit 3ead2eb

Browse files
authored
[Exec] (performance) support stream agg topn (#59446)
### What problem does this PR solve? Query cost before this change: 2s; query cost after this change: 1s.
1 parent d86daef commit 3ead2eb

File tree

5 files changed

+329
-30
lines changed

5 files changed

+329
-30
lines changed

be/src/pipeline/exec/streaming_aggregation_operator.cpp

Lines changed: 250 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) {
100100
_insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime");
101101
_deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime");
102102
_hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime");
103+
_hash_table_limit_compute_timer =
104+
ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime");
103105
_hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime");
104106
_hash_table_input_counter =
105107
ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT);
@@ -153,16 +155,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
153155
}},
154156
_agg_data->method_variant);
155157

156-
if (p._is_merge || p._needs_finalize) {
157-
return Status::InvalidArgument(
158-
"StreamingAggLocalState only support no merge and no finalize, "
159-
"but got is_merge={}, needs_finalize={}",
160-
p._is_merge, p._needs_finalize);
161-
}
162-
163-
_should_limit_output = p._limit != -1 && // has limit
164-
(!p._have_conjuncts) && // no having conjunct
165-
p._needs_finalize; // agg's finalize step
158+
limit = p._sort_limit;
159+
do_sort_limit = p._do_sort_limit;
160+
null_directions = p._null_directions;
161+
order_directions = p._order_directions;
166162

167163
return Status::OK();
168164
}
@@ -364,6 +360,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
364360
_places.resize(rows);
365361

366362
if (_should_not_do_pre_agg(rows)) {
363+
if (limit > 0) {
364+
DCHECK(do_sort_limit);
365+
if (need_do_sort_limit == -1) {
366+
const size_t hash_table_size = _get_hash_table_size();
367+
need_do_sort_limit = hash_table_size >= limit ? 1 : 0;
368+
if (need_do_sort_limit == 1) {
369+
build_limit_heap(hash_table_size);
370+
}
371+
}
372+
373+
if (need_do_sort_limit == 1) {
374+
if (_do_limit_filter(rows, key_columns)) {
375+
bool need_filter = std::find(need_computes.begin(), need_computes.end(), 1) !=
376+
need_computes.end();
377+
if (need_filter) {
378+
_add_limit_heap_top(key_columns, rows);
379+
vectorized::Block::filter_block_internal(in_block, need_computes);
380+
rows = (uint32_t)in_block->rows();
381+
} else {
382+
return Status::OK();
383+
}
384+
}
385+
}
386+
}
367387
bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse();
368388

369389
std::vector<vectorized::DataTypePtr> data_types;
@@ -405,12 +425,23 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
405425
}
406426
}
407427
} else {
408-
_emplace_into_hash_table(_places.data(), key_columns, rows);
428+
bool need_agg = true;
429+
if (need_do_sort_limit != 1) {
430+
_emplace_into_hash_table(_places.data(), key_columns, rows);
431+
} else {
432+
need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows);
433+
}
409434

410-
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
411-
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
412-
in_block, p._offsets_of_aggregate_states[i], _places.data(), _agg_arena_pool,
413-
_should_expand_hash_table));
435+
if (need_agg) {
436+
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
437+
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
438+
in_block, p._offsets_of_aggregate_states[i], _places.data(),
439+
_agg_arena_pool, _should_expand_hash_table));
440+
}
441+
if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) {
442+
need_do_sort_limit = 1;
443+
build_limit_heap(_get_hash_table_size());
444+
}
414445
}
415446
}
416447

@@ -562,6 +593,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da
562593
}
563594
}
564595

596+
/// Copies the group-by keys currently held by the aggregation state into freshly created
/// columns, one column per probe expression.
///
/// Keys are read from `_aggregate_data_container` in its iteration order and written to the
/// columns through the hash method's `insert_keys_into_columns`. When the hash table reports
/// null-key data, one extra entry is appended to the first key column with
/// `insert_data(nullptr, 0)`.
///
/// @return the key columns built from the current hash table contents.
/// @throws doris::Exception (INTERNAL_ERROR) when the hash table variant is uninitialized.
vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table() {
    return std::visit(
            vectorized::Overload {
                    [&](std::monostate& arg) {
                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
                        // Unreachable; kept so this lambda's return type matches the other
                        // alternative handled by std::visit.
                        return vectorized::MutableColumns();
                    },
                    [&](auto&& agg_method) -> vectorized::MutableColumns {
                        using KeyType = typename std::decay_t<decltype(agg_method)>::Key;

                        // One empty output column per group-by expression.
                        vectorized::MutableColumns columns;
                        for (const auto& expr_ctx : _probe_expr_ctxs) {
                            columns.emplace_back(expr_ctx->root()->data_type()->create_column());
                        }

                        auto& table = *agg_method.hash_table;
                        const bool null_key_present = table.has_null_key_data();
                        // The buffer is sized without the null-key entry, which is appended
                        // separately once the regular keys have been written.
                        std::vector<KeyType> gathered(table.size() - null_key_present);

                        uint32_t gathered_count = 0;
                        for (auto it = _aggregate_data_container->begin();
                             it != _aggregate_data_container->end(); ++it) {
                            gathered[gathered_count] = it.get_key<KeyType>();
                            ++gathered_count;
                        }

                        agg_method.insert_keys_into_columns(gathered, columns, gathered_count);
                        if (null_key_present) {
                            columns[0]->insert_data(nullptr, 0);
                        }
                        return columns;
                    }},
            _agg_data->method_variant);
}
632+
633+
void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) {
634+
limit_columns = _get_keys_hash_table();
635+
for (size_t i = 0; i < hash_table_size; ++i) {
636+
limit_heap.emplace(i, limit_columns, order_directions, null_directions);
637+
}
638+
while (hash_table_size > limit) {
639+
limit_heap.pop();
640+
hash_table_size--;
641+
}
642+
limit_columns_min = limit_heap.top()._row_id;
643+
}
644+
645+
/// Pushes at most one row of the current batch into the limit heap.
///
/// Scans rows in order and stops at the first row whose `cmp_res` entry equals 1 and whose
/// `need_computes` entry is non-zero. That row's keys are appended to `limit_columns`, pushed
/// into the heap, one heap entry is popped, and `limit_columns_min` is refreshed; all of that
/// is done by `_refresh_limit_heap`. Later qualifying rows in the batch are ignored.
///
/// @param key_columns key columns of the current batch.
/// @param rows        number of rows to scan in `cmp_res` / `need_computes`.
void StreamingAggLocalState::_add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns,
                                                 size_t rows) {
    // size_t index: `rows` is unsigned, so a signed counter would mix signedness in the bound
    // check.
    for (size_t row = 0; row < rows; ++row) {
        if (cmp_res[row] == 1 && need_computes[row]) {
            _refresh_limit_heap(row, key_columns);
            break;
        }
    }
}
660+
661+
void StreamingAggLocalState::_refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns) {
662+
for (int j = 0; j < key_columns.size(); ++j) {
663+
limit_columns[j]->insert_from(*key_columns[j], i);
664+
}
665+
limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
666+
null_directions);
667+
limit_heap.pop();
668+
limit_columns_min = limit_heap.top()._row_id;
669+
}
670+
671+
// Emplaces the rows of `block` that survive the sort-limit filter into the hash table.
//
// Steps, in order:
//   1. `_do_limit_filter` recomputes `need_computes` for the batch.
//   2. If no entry of `need_computes` equals 1, nothing is emplaced and false is returned.
//   3. If the filter reported rows to drop, `block` is filtered with `need_computes` and
//      `num_rows` is replaced by the filtered row count.
//   4. Every remaining row is passed to `lazy_emplace`; the resulting aggregate-state pointer
//      is stored in `places[i]`. Both creator callbacks also call `_refresh_limit_heap` for
//      the current row.
//
// Parameters:
//   places      - output array; receives one aggregate-state pointer per emplaced row.
//   block       - input block; filtered in place when step 3 applies.
//   key_columns - group-by key columns of the batch.
//   num_rows    - number of rows in the batch before filtering.
//
// Returns true when rows were emplaced (the caller should run the aggregate functions),
// false when no row of the batch needs to be computed.
//
// Throws doris::Exception when the hash table variant is uninitialized, or when creating an
// aggregate state fails inside one of the creator callbacks.
bool StreamingAggLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
672+
vectorized::Block* block,
673+
vectorized::ColumnRawPtrs& key_columns,
674+
uint32_t num_rows) {
675+
return std::visit(
676+
vectorized::Overload {
677+
[&](std::monostate& arg) {
678+
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
// Unreachable after the throw; gives this lambda a bool return type like the other branch.
679+
return true;
680+
},
681+
[&](auto&& agg_method) -> bool {
682+
SCOPED_TIMER(_hash_table_compute_timer);
683+
using HashMethodType = std::decay_t<decltype(agg_method)>;
684+
using AggState = typename HashMethodType::State;
685+
// need_filter is true when at least one entry of need_computes is 0 after the recompute.
686+
bool need_filter = _do_limit_filter(num_rows, key_columns);
// Continue only when at least one row is flagged with 1 in need_computes.
687+
if (auto need_agg =
688+
std::find(need_computes.begin(), need_computes.end(), 1);
689+
need_agg != need_computes.end()) {
690+
if (need_filter) {
// NOTE(review): key_columns holds raw column pointers obtained before this filter;
// confirm they still refer to the filtered columns afterwards.
691+
vectorized::Block::filter_block_internal(block, need_computes);
692+
num_rows = (uint32_t)block->rows();
693+
}
694+
695+
AggState state(key_columns);
696+
agg_method.init_serialized_keys(key_columns, num_rows);
// `i` is captured by reference by both creators below, so they see the row index that the
// emplace loop is currently processing.
697+
size_t i = 0;
698+
// Callback handed to lazy_emplace (presumably invoked for keys not yet in the table;
// confirm against lazy_emplace). Allocates and initializes the aggregate state for the key,
// registers it through `ctor`, then pushes the current row into the limit heap.
699+
auto creator = [&](const auto& ctor, auto& key, auto& origin) {
700+
try {
701+
HashMethodType::try_presis_key_and_origin(key, origin,
702+
_agg_arena_pool);
703+
auto mapped = _aggregate_data_container->append_data(origin);
704+
auto st = _create_agg_status(mapped);
705+
if (!st) {
706+
throw Exception(st.code(), st.to_string());
707+
}
708+
ctor(key, mapped);
709+
_refresh_limit_heap(i, key_columns);
710+
} catch (...) {
711+
// Exception-safety - if it can not allocate memory or create status,
712+
// the destructors will not be called.
713+
ctor(key, nullptr);
714+
throw;
715+
}
716+
};
717+
// Callback handed to lazy_emplace for the null key: the aggregate state is allocated from
// the arena with the operator's total state size and alignment, then initialized.
718+
auto creator_for_null_key = [&](auto& mapped) {
719+
mapped = _agg_arena_pool.aligned_alloc(
720+
Base::_parent->template cast<StreamingAggOperatorX>()
721+
._total_size_of_aggregate_states,
722+
Base::_parent->template cast<StreamingAggOperatorX>()
723+
._align_aggregate_states);
724+
auto st = _create_agg_status(mapped);
725+
if (!st) {
726+
throw Exception(st.code(), st.to_string());
727+
}
728+
_refresh_limit_heap(i, key_columns);
729+
};
730+
731+
SCOPED_TIMER(_hash_table_emplace_timer);
// Uses the outer `i` (not a fresh loop variable) so the creators observe the current row.
732+
for (i = 0; i < num_rows; ++i) {
733+
places[i] = *agg_method.lazy_emplace(state, i, creator,
734+
creator_for_null_key);
735+
}
736+
COUNTER_UPDATE(_hash_table_input_counter, num_rows);
737+
return true;
738+
}
// No row of this batch is flagged in need_computes: nothing to emplace or aggregate.
739+
return false;
740+
}},
741+
_agg_data->method_variant);
742+
}
743+
744+
bool StreamingAggLocalState::_do_limit_filter(size_t num_rows,
745+
vectorized::ColumnRawPtrs& key_columns) {
746+
SCOPED_TIMER(_hash_table_limit_compute_timer);
747+
if (num_rows) {
748+
cmp_res.resize(num_rows);
749+
need_computes.resize(num_rows);
750+
memset(need_computes.data(), 0, need_computes.size());
751+
memset(cmp_res.data(), 0, cmp_res.size());
752+
753+
const auto key_size = null_directions.size();
754+
for (int i = 0; i < key_size; i++) {
755+
key_columns[i]->compare_internal(limit_columns_min, *limit_columns[i],
756+
null_directions[i], order_directions[i], cmp_res,
757+
need_computes.data());
758+
}
759+
760+
auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) {
761+
for (size_t i = 0; i < rows; ++i) {
762+
computes[i] = computes[i] == res[i];
763+
}
764+
};
765+
set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
766+
767+
return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
768+
}
769+
770+
return false;
771+
}
772+
565773
void StreamingAggLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* places,
566774
vectorized::ColumnRawPtrs& key_columns,
567775
const uint32_t num_rows) {
@@ -617,7 +825,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
617825
_intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
618826
_output_tuple_id(tnode.agg_node.output_tuple_id),
619827
_needs_finalize(tnode.agg_node.need_finalize),
620-
_is_merge(false),
621828
_is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
622829
_have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
623830
_agg_fn_output_row_descriptor(descs, tnode.row_tuples),
@@ -669,8 +876,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
669876
}
670877

671878
const auto& agg_functions = tnode.agg_node.aggregate_functions;
672-
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
673-
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
879+
auto is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
880+
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
881+
if (is_merge || _needs_finalize) {
882+
return Status::InvalidArgument(
883+
"StreamingAggLocalState only support no merge and no finalize, "
884+
"but got is_merge={}, needs_finalize={}",
885+
is_merge, _needs_finalize);
886+
}
887+
888+
// Handle sort limit
889+
if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
890+
_sort_limit = _limit;
891+
_limit = -1;
892+
_do_sort_limit = true;
893+
const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
894+
DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());
895+
896+
const size_t order_by_key_size = agg_sort_info.is_asc_order.size();
897+
_order_directions.resize(order_by_key_size);
898+
_null_directions.resize(order_by_key_size);
899+
for (int i = 0; i < order_by_key_size; ++i) {
900+
_order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
901+
_null_directions[i] =
902+
agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
903+
}
904+
}
905+
674906
_op_name = "STREAMING_AGGREGATION_OPERATOR";
675907
return Status::OK();
676908
}

0 commit comments

Comments
 (0)