Skip to content

Commit afc9ac1

Browse files
HappenLee and zclllyybb
authored and committed
support stream agg topn
1 parent 2771e6a commit afc9ac1

File tree

3 files changed

+344
-39
lines changed

3 files changed

+344
-39
lines changed

be/src/pipeline/exec/streaming_aggregation_operator.cpp

Lines changed: 266 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) {
9999
_insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime");
100100
_deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime");
101101
_hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime");
102+
_hash_table_limit_compute_timer =
103+
ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime");
102104
_hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime");
103105
_hash_table_input_counter =
104106
ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT);
@@ -152,16 +154,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
152154
}},
153155
_agg_data->method_variant);
154156

155-
if (p._is_merge || p._needs_finalize) {
156-
return Status::InvalidArgument(
157-
"StreamingAggLocalState only support no merge and no finalize, "
158-
"but got is_merge={}, needs_finalize={}",
159-
p._is_merge, p._needs_finalize);
160-
}
161-
162-
_should_limit_output = p._limit != -1 && // has limit
163-
(!p._have_conjuncts) && // no having conjunct
164-
p._needs_finalize; // agg's finalize step
157+
limit = p._sort_limit;
158+
do_sort_limit = p._do_sort_limit;
159+
null_directions = p._null_directions;
160+
order_directions = p._order_directions;
165161

166162
return Status::OK();
167163
}
@@ -316,23 +312,22 @@ bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) {
316312
const auto spill_streaming_agg_mem_limit = p._spill_streaming_agg_mem_limit;
317313
const bool used_too_much_memory =
318314
spill_streaming_agg_mem_limit > 0 && _memory_usage() > spill_streaming_agg_mem_limit;
319-
std::visit(
320-
vectorized::Overload {
321-
[&](std::monostate& arg) {
322-
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
323-
},
324-
[&](auto& agg_method) {
325-
auto& hash_tbl = *agg_method.hash_table;
326-
/// If too much memory is used during the pre-aggregation stage,
327-
/// it is better to output the data directly without performing further aggregation.
328-
// do not try to do agg, just init and serialize directly return the out_block
329-
if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) &&
330-
!_should_expand_preagg_hash_tables())) {
331-
SCOPED_TIMER(_streaming_agg_timer);
332-
ret_flag = true;
333-
}
334-
}},
335-
_agg_data->method_variant);
315+
std::visit(vectorized::Overload {
316+
[&](std::monostate& arg) {
317+
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
318+
},
319+
[&](auto& agg_method) {
320+
auto& hash_tbl = *agg_method.hash_table;
321+
/// If too much memory is used during the pre-aggregation stage,
322+
/// it is better to output the data directly without performing further aggregation.
323+
// do not try to do agg, just init and serialize directly return the out_block
324+
if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) &&
325+
!_should_expand_preagg_hash_tables())) {
326+
SCOPED_TIMER(_streaming_agg_timer);
327+
ret_flag = true;
328+
}
329+
}},
330+
_agg_data->method_variant);
336331

337332
return ret_flag;
338333
}
@@ -363,6 +358,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
363358
_places.resize(rows);
364359

365360
if (_should_not_do_pre_agg(rows)) {
361+
if (limit > 0) {
362+
DCHECK(do_sort_limit);
363+
if (need_do_sort_limit == -1) {
364+
const size_t hash_table_size = _get_hash_table_size();
365+
need_do_sort_limit = hash_table_size >= limit ? 1 : 0;
366+
if (need_do_sort_limit == 1) {
367+
build_limit_heap(hash_table_size);
368+
}
369+
}
370+
371+
if (need_do_sort_limit == 1) {
372+
if (_do_limit_filter(rows, key_columns)) {
373+
bool need_filter = std::find(need_computes.begin(), need_computes.end(), 1) !=
374+
need_computes.end();
375+
if (need_filter) {
376+
_add_limit_heap_top(key_columns, rows);
377+
vectorized::Block::filter_block_internal(in_block, need_computes);
378+
rows = (uint32_t)in_block->rows();
379+
} else {
380+
return Status::OK();
381+
}
382+
}
383+
}
384+
}
366385
bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse();
367386

368387
std::vector<vectorized::DataTypePtr> data_types;
@@ -404,12 +423,23 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
404423
}
405424
}
406425
} else {
407-
_emplace_into_hash_table(_places.data(), key_columns, rows);
426+
bool need_agg = true;
427+
if (need_do_sort_limit != 1) {
428+
_emplace_into_hash_table(_places.data(), key_columns, rows);
429+
} else {
430+
need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows);
431+
}
408432

409-
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
410-
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
411-
in_block, p._offsets_of_aggregate_states[i], _places.data(), _agg_arena_pool,
412-
_should_expand_hash_table));
433+
if (need_agg) {
434+
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
435+
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
436+
in_block, p._offsets_of_aggregate_states[i], _places.data(),
437+
_agg_arena_pool, _should_expand_hash_table));
438+
}
439+
if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) {
440+
need_do_sort_limit = 1;
441+
build_limit_heap(_get_hash_table_size());
442+
}
413443
}
414444
}
415445

@@ -561,6 +591,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da
561591
}
562592
}
563593

594+
/// Copies every group-by key currently stored in the aggregation state into newly created
/// columns, one column per probe expression.
///
/// Keys are read from `_aggregate_data_container` in its iteration order and written through
/// the hash method's `insert_keys_into_columns`. If the hash table reports a null key, one
/// additional entry is appended to the first key column via `insert_data(nullptr, 0)`.
///
/// @return the freshly built key columns.
/// @throws doris::Exception (INTERNAL_ERROR) if the hash-table variant is still std::monostate.
vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table() {
    return std::visit(
            vectorized::Overload {
                    [&](std::monostate&) -> vectorized::MutableColumns {
                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
                    },
                    [&](auto&& agg_method) -> vectorized::MutableColumns {
                        // One empty column per group-by expression, typed after that expression.
                        vectorized::MutableColumns result;
                        for (size_t col = 0; col < _probe_expr_ctxs.size(); ++col) {
                            result.emplace_back(
                                    _probe_expr_ctxs[col]->root()->data_type()->create_column());
                        }

                        auto& table = *agg_method.hash_table;
                        const bool null_key_present = table.has_null_key_data();
                        // The null key is handled separately below, so it is excluded from
                        // the buffer sizing.
                        const auto non_null_key_count = table.size() - null_key_present;
                        using KeyType = typename std::decay_t<decltype(agg_method)>::Key;
                        std::vector<KeyType> keys(non_null_key_count);

                        uint32_t num_rows = 0;
                        for (auto it = _aggregate_data_container->begin();
                             it != _aggregate_data_container->end(); ++it) {
                            keys[num_rows] = it.get_key<KeyType>();
                            ++num_rows;
                        }

                        agg_method.insert_keys_into_columns(keys, result, num_rows);
                        if (null_key_present) {
                            result[0]->insert_data(nullptr, 0);
                        }
                        return result;
                    }},
            _agg_data->method_variant);
}
630+
631+
void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) {
632+
limit_columns = _get_keys_hash_table();
633+
for (size_t i = 0; i < hash_table_size; ++i) {
634+
limit_heap.emplace(i, limit_columns, order_directions, null_directions);
635+
}
636+
while (hash_table_size > limit) {
637+
limit_heap.pop();
638+
hash_table_size--;
639+
}
640+
limit_columns_min = limit_heap.top()._row_id;
641+
}
642+
643+
/// Feeds the first qualifying row of the current batch into the sort-limit heap.
///
/// A row qualifies when `cmp_res[i] == 1` and `need_computes[i]` is non-zero; both arrays
/// are expected to have been filled by a preceding `_do_limit_filter` call for this batch.
/// Only the first qualifying row is used: the scan stops right after it has been handed to
/// `_refresh_limit_heap`, which appends the row's keys to `limit_columns`, swaps it into the
/// heap and updates `limit_columns_min`.
///
/// @param key_columns group-by key columns of the current batch.
/// @param rows        number of rows to scan.
void StreamingAggLocalState::_add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns,
                                                 size_t rows) {
    for (size_t i = 0; i < rows; ++i) {
        if (cmp_res[i] == 1 && need_computes[i]) {
            // Same append / emplace / pop / re-read-top sequence as every other heap
            // update; share the helper instead of duplicating it here.
            _refresh_limit_heap(i, key_columns);
            return;
        }
    }
}
658+
659+
void StreamingAggLocalState::_refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns) {
660+
for (int j = 0; j < key_columns.size(); ++j) {
661+
limit_columns[j]->insert_from(*key_columns[j], i);
662+
}
663+
limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
664+
null_directions);
665+
limit_heap.pop();
666+
limit_columns_min = limit_heap.top()._row_id;
667+
}
668+
669+
/// Hash-table emplace used once sort-limit pruning is active.
///
/// The batch is first run through `_do_limit_filter`. If no entry of `need_computes` is 1
/// afterwards, nothing is inserted and `false` is returned. Otherwise the block is filtered
/// by `need_computes` (only when `_do_limit_filter` reported that some rows were rejected),
/// the surviving rows are emplaced into the hash table, `places[row]` receives the
/// aggregate state pointer of each row, and `true` is returned. Every newly created key
/// (including the null key) is also pushed into the limit heap via `_refresh_limit_heap`.
///
/// @param places      output array; one aggregate-state pointer per processed row.
/// @param block       input block; may be filtered in place.
/// @param key_columns group-by key columns of the batch.
/// @param num_rows    number of rows in the batch before filtering.
/// @return true if the caller should go on to run the aggregate functions on `places`.
/// @throws doris::Exception if the hash-table variant is std::monostate; `Exception` is also
///         thrown when `_create_agg_status` reports a failure.
bool StreamingAggLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
                                                            vectorized::Block* block,
                                                            vectorized::ColumnRawPtrs& key_columns,
                                                            uint32_t num_rows) {
    return std::visit(
            vectorized::Overload {
                    [&](std::monostate&) -> bool {
                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
                    },
                    [&](auto&& agg_method) -> bool {
                        SCOPED_TIMER(_hash_table_compute_timer);
                        using HashMethodType = std::decay_t<decltype(agg_method)>;
                        using AggState = typename HashMethodType::State;

                        const bool some_rows_rejected = _do_limit_filter(num_rows, key_columns);
                        const bool any_row_kept =
                                std::find(need_computes.begin(), need_computes.end(), 1) !=
                                need_computes.end();
                        if (!any_row_kept) {
                            return false;
                        }

                        if (some_rows_rejected) {
                            // NOTE(review): key_columns is not re-read from `block` after the
                            // filter; confirm the raw pointers still refer to the filtered
                            // columns.
                            vectorized::Block::filter_block_internal(block, need_computes);
                            num_rows = (uint32_t)block->rows();
                        }

                        AggState state(key_columns);
                        agg_method.init_serialized_keys(key_columns, num_rows);
                        // Captured by reference by both creators so they know which row
                        // triggered the insertion.
                        size_t row = 0;

                        auto on_new_key = [&](const auto& ctor, auto& key, auto& origin) {
                            try {
                                HashMethodType::try_presis_key_and_origin(key, origin,
                                                                          _agg_arena_pool);
                                auto mapped = _aggregate_data_container->append_data(origin);
                                auto st = _create_agg_status(mapped);
                                if (!st) {
                                    throw Exception(st.code(), st.to_string());
                                }
                                ctor(key, mapped);
                                _refresh_limit_heap(row, key_columns);
                            } catch (...) {
                                // Exception-safety - if it can not allocate memory or create status,
                                // the destructors will not be called.
                                ctor(key, nullptr);
                                throw;
                            }
                        };

                        auto on_new_null_key = [&](auto& mapped) {
                            const auto& parent_op =
                                    Base::_parent->template cast<StreamingAggOperatorX>();
                            mapped = _agg_arena_pool.aligned_alloc(
                                    parent_op._total_size_of_aggregate_states,
                                    parent_op._align_aggregate_states);
                            auto st = _create_agg_status(mapped);
                            if (!st) {
                                throw Exception(st.code(), st.to_string());
                            }
                            _refresh_limit_heap(row, key_columns);
                        };

                        SCOPED_TIMER(_hash_table_emplace_timer);
                        for (row = 0; row < num_rows; ++row) {
                            places[row] = *agg_method.lazy_emplace(state, row, on_new_key,
                                                                   on_new_null_key);
                        }
                        COUNTER_UPDATE(_hash_table_input_counter, num_rows);
                        return true;
                    }},
            _agg_data->method_variant);
}
741+
742+
bool StreamingAggLocalState::_do_limit_filter(size_t num_rows,
743+
vectorized::ColumnRawPtrs& key_columns) {
744+
SCOPED_TIMER(_hash_table_limit_compute_timer);
745+
if (num_rows) {
746+
cmp_res.resize(num_rows);
747+
need_computes.resize(num_rows);
748+
memset(need_computes.data(), 0, need_computes.size());
749+
memset(cmp_res.data(), 0, cmp_res.size());
750+
751+
const auto key_size = null_directions.size();
752+
for (int i = 0; i < key_size; i++) {
753+
key_columns[i]->compare_internal(limit_columns_min, *limit_columns[i],
754+
null_directions[i], order_directions[i], cmp_res,
755+
need_computes.data());
756+
}
757+
758+
auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) {
759+
for (size_t i = 0; i < rows; ++i) {
760+
computes[i] = computes[i] == res[i];
761+
}
762+
};
763+
set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
764+
765+
return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
766+
}
767+
768+
return false;
769+
}
770+
564771
void StreamingAggLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* places,
565772
vectorized::ColumnRawPtrs& key_columns,
566773
const uint32_t num_rows) {
@@ -616,7 +823,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
616823
_intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
617824
_output_tuple_id(tnode.agg_node.output_tuple_id),
618825
_needs_finalize(tnode.agg_node.need_finalize),
619-
_is_merge(false),
620826
_is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
621827
_have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
622828
_agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples),
@@ -668,8 +874,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
668874
}
669875

670876
const auto& agg_functions = tnode.agg_node.aggregate_functions;
671-
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
672-
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
877+
auto is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
878+
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
879+
if (is_merge || _needs_finalize) {
880+
return Status::InvalidArgument(
881+
"StreamingAggLocalState only support no merge and no finalize, "
882+
"but got is_merge={}, needs_finalize={}",
883+
is_merge, _needs_finalize);
884+
}
885+
886+
// Handle sort limit
887+
if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
888+
_sort_limit = _limit;
889+
_limit = -1;
890+
_do_sort_limit = true;
891+
const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
892+
DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());
893+
894+
const size_t order_by_key_size = agg_sort_info.is_asc_order.size();
895+
_order_directions.resize(order_by_key_size);
896+
_null_directions.resize(order_by_key_size);
897+
for (int i = 0; i < order_by_key_size; ++i) {
898+
_order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
899+
_null_directions[i] =
900+
agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
901+
}
902+
}
903+
673904
_op_name = "STREAMING_AGGREGATION_OPERATOR";
674905
return Status::OK();
675906
}

0 commit comments

Comments
 (0)