268 changes: 250 additions & 18 deletions be/src/pipeline/exec/streaming_aggregation_operator.cpp
@@ -100,6 +100,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) {
_insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime");
_deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime");
_hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime");
_hash_table_limit_compute_timer =
ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime");
_hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime");
_hash_table_input_counter =
ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT);
@@ -153,16 +155,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
}},
_agg_data->method_variant);

if (p._is_merge || p._needs_finalize) {
return Status::InvalidArgument(
"StreamingAggLocalState only support no merge and no finalize, "
"but got is_merge={}, needs_finalize={}",
p._is_merge, p._needs_finalize);
}

_should_limit_output = p._limit != -1 && // has limit
(!p._have_conjuncts) && // no having conjunct
p._needs_finalize; // agg's finalize step
limit = p._sort_limit;
do_sort_limit = p._do_sort_limit;
null_directions = p._null_directions;
order_directions = p._order_directions;

return Status::OK();
}
@@ -364,6 +360,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B
_places.resize(rows);

if (_should_not_do_pre_agg(rows)) {
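        // Pass-through path with a sort limit: need_do_sort_limit is a tri-state
        // (-1 = not yet decided, 0 = hash table was still below the limit when
        // first checked, 1 = heap-based top-N filtering is active). Rows that
        // cannot enter the top-N are dropped before being streamed through.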
if (limit > 0) {
DCHECK(do_sort_limit);
if (need_do_sort_limit == -1) {
const size_t hash_table_size = _get_hash_table_size();
need_do_sort_limit = hash_table_size >= limit ? 1 : 0;
if (need_do_sort_limit == 1) {
build_limit_heap(hash_table_size);
}
}

if (need_do_sort_limit == 1) {
if (_do_limit_filter(rows, key_columns)) {
bool need_filter = std::find(need_computes.begin(), need_computes.end(), 1) !=
need_computes.end();
if (need_filter) {
_add_limit_heap_top(key_columns, rows);
vectorized::Block::filter_block_internal(in_block, need_computes);
rows = (uint32_t)in_block->rows();
} else {
return Status::OK();
}
}
}
}
bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse();

std::vector<vectorized::DataTypePtr> data_types;
@@ -405,12 +425,23 @@
}
}
} else {
_emplace_into_hash_table(_places.data(), key_columns, rows);
bool need_agg = true;
if (need_do_sort_limit != 1) {
_emplace_into_hash_table(_places.data(), key_columns, rows);
} else {
need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows);
}

for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
in_block, p._offsets_of_aggregate_states[i], _places.data(), _agg_arena_pool,
_should_expand_hash_table));
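            // _emplace_into_hash_table_limit() returns false when every row of
            // the block was filtered out by the limit heap; skip aggregate
            // execution in that case.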
if (need_agg) {
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
in_block, p._offsets_of_aggregate_states[i], _places.data(),
_agg_arena_pool, _should_expand_hash_table));
}
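                // Lazily switch to top-N mode: once the hash table holds `limit`
                // groups, seed the heap with the keys accumulated so far.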
if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) {
need_do_sort_limit = 1;
build_limit_heap(_get_hash_table_size());
}
}
}

@@ -562,6 +593,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da
}
}

vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table() {
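    // Materialize every key currently in the hash table into freshly created
    // columns; build_limit_heap() uses the result to seed the top-N heap.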
return std::visit(
vectorized::Overload {
[&](std::monostate& arg) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
return vectorized::MutableColumns();
},
[&](auto&& agg_method) -> vectorized::MutableColumns {
vectorized::MutableColumns key_columns;
for (int i = 0; i < _probe_expr_ctxs.size(); ++i) {
key_columns.emplace_back(
_probe_expr_ctxs[i]->root()->data_type()->create_column());
}
auto& data = *agg_method.hash_table;
bool has_null_key = data.has_null_key_data();
const auto size = data.size() - has_null_key;
using KeyType = std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);

uint32_t num_rows = 0;
auto iter = _aggregate_data_container->begin();
{
while (iter != _aggregate_data_container->end()) {
keys[num_rows] = iter.get_key<KeyType>();
++iter;
++num_rows;
}
}
agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
if (has_null_key) {
key_columns[0]->insert_data(nullptr, 0);
}
return key_columns;
}},
_agg_data->method_variant);
}

void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) {
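    // Push every existing key into the heap, then pop the worst entries until
    // only `limit` remain; the heap top becomes the boundary key that incoming
    // keys must beat (tracked via limit_columns_min).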
limit_columns = _get_keys_hash_table();
for (size_t i = 0; i < hash_table_size; ++i) {
limit_heap.emplace(i, limit_columns, order_directions, null_directions);
}
while (hash_table_size > limit) {
limit_heap.pop();
hash_table_size--;
}
limit_columns_min = limit_heap.top()._row_id;
}

void StreamingAggLocalState::_add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns,
size_t rows) {
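    // In the pass-through path, insert the first surviving row that strictly
    // beat the old boundary (cmp_res == 1) and evict the heap top, tightening
    // the bound for subsequent blocks.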
for (int i = 0; i < rows; ++i) {
if (cmp_res[i] == 1 && need_computes[i]) {
for (int j = 0; j < key_columns.size(); ++j) {
limit_columns[j]->insert_from(*key_columns[j], i);
}
limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
null_directions);
limit_heap.pop();
limit_columns_min = limit_heap.top()._row_id;
break;
}
}
}

void StreamingAggLocalState::_refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns) {
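    // Called after a new key is emplaced into the hash table: push it into the
    // heap, evict the current top, and refresh limit_columns_min.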
for (int j = 0; j < key_columns.size(); ++j) {
limit_columns[j]->insert_from(*key_columns[j], i);
}
limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions,
null_directions);
limit_heap.pop();
limit_columns_min = limit_heap.top()._row_id;
}

bool StreamingAggLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
vectorized::Block* block,
vectorized::ColumnRawPtrs& key_columns,
uint32_t num_rows) {
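    // Limit-aware variant of _emplace_into_hash_table(): rows whose key cannot
    // enter the top-N are filtered out before emplacement. Returns false when
    // no row survives, so the caller can skip aggregate execution entirely.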
return std::visit(
vectorized::Overload {
[&](std::monostate& arg) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
return true;
},
[&](auto&& agg_method) -> bool {
SCOPED_TIMER(_hash_table_compute_timer);
using HashMethodType = std::decay_t<decltype(agg_method)>;
using AggState = typename HashMethodType::State;

bool need_filter = _do_limit_filter(num_rows, key_columns);
if (auto need_agg =
std::find(need_computes.begin(), need_computes.end(), 1);
need_agg != need_computes.end()) {
if (need_filter) {
vectorized::Block::filter_block_internal(block, need_computes);
num_rows = (uint32_t)block->rows();
}

AggState state(key_columns);
agg_method.init_serialized_keys(key_columns, num_rows);
size_t i = 0;

auto creator = [&](const auto& ctor, auto& key, auto& origin) {
try {
HashMethodType::try_presis_key_and_origin(key, origin,
_agg_arena_pool);
auto mapped = _aggregate_data_container->append_data(origin);
auto st = _create_agg_status(mapped);
if (!st) {
throw Exception(st.code(), st.to_string());
}
ctor(key, mapped);
_refresh_limit_heap(i, key_columns);
} catch (...) {
                                // Exception safety: if it cannot allocate memory or
                                // create the agg status, the destructors will not be called.
ctor(key, nullptr);
throw;
}
};

auto creator_for_null_key = [&](auto& mapped) {
mapped = _agg_arena_pool.aligned_alloc(
Base::_parent->template cast<StreamingAggOperatorX>()
._total_size_of_aggregate_states,
Base::_parent->template cast<StreamingAggOperatorX>()
._align_aggregate_states);
auto st = _create_agg_status(mapped);
if (!st) {
throw Exception(st.code(), st.to_string());
}
_refresh_limit_heap(i, key_columns);
};

SCOPED_TIMER(_hash_table_emplace_timer);
for (i = 0; i < num_rows; ++i) {
places[i] = *agg_method.lazy_emplace(state, i, creator,
creator_for_null_key);
}
COUNTER_UPDATE(_hash_table_input_counter, num_rows);
return true;
}
return false;
}},
_agg_data->method_variant);
}

bool StreamingAggLocalState::_do_limit_filter(size_t num_rows,
vectorized::ColumnRawPtrs& key_columns) {
SCOPED_TIMER(_hash_table_limit_compute_timer);
if (num_rows) {
cmp_res.resize(num_rows);
need_computes.resize(num_rows);
memset(need_computes.data(), 0, need_computes.size());
memset(cmp_res.data(), 0, cmp_res.size());

const auto key_size = null_directions.size();
for (int i = 0; i < key_size; i++) {
key_columns[i]->compare_internal(limit_columns_min, *limit_columns[i],
null_directions[i], order_directions[i], cmp_res,
need_computes.data());
}

auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) {
for (size_t i = 0; i < rows; ++i) {
computes[i] = computes[i] == res[i];
}
};
set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);

return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
}

return false;
}

void StreamingAggLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* places,
vectorized::ColumnRawPtrs& key_columns,
const uint32_t num_rows) {
@@ -617,7 +825,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
_intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
_output_tuple_id(tnode.agg_node.output_tuple_id),
_needs_finalize(tnode.agg_node.need_finalize),
_is_merge(false),
_is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
_have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
_agg_fn_output_row_descriptor(descs, tnode.row_tuples),
@@ -669,8 +876,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
}

const auto& agg_functions = tnode.agg_node.aggregate_functions;
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
auto is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
if (is_merge || _needs_finalize) {
return Status::InvalidArgument(
"StreamingAggLocalState only support no merge and no finalize, "
"but got is_merge={}, needs_finalize={}",
is_merge, _needs_finalize);
}

// Handle sort limit
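    // The planner sets agg_sort_info_by_group_key when the LIMIT can be pushed
    // down as a top-N over the group-by keys: _limit moves into _sort_limit and
    // per-key directions are derived (1 = asc, -1 = desc; nulls_first flips the
    // null direction relative to the order direction).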
if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
_sort_limit = _limit;
_limit = -1;
_do_sort_limit = true;
const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());

const size_t order_by_key_size = agg_sort_info.is_asc_order.size();
_order_directions.resize(order_by_key_size);
_null_directions.resize(order_by_key_size);
for (int i = 0; i < order_by_key_size; ++i) {
_order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
_null_directions[i] =
agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
}
}

_op_name = "STREAMING_AGGREGATION_OPERATOR";
return Status::OK();
}
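The new code paths above implement a streaming top-N over the group-by keys: a heap tracks the worst kept key, and whole blocks of rows are pre-filtered against that boundary before they ever touch the hash table. Below is a minimal, self-contained sketch of that idea for reference — an editor's illustration, not Doris code — assuming a single int64 ascending sort key and that rows of an already-kept group are always aggregated:

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <set>

struct TopNGroupFilter {
    size_t limit;
    std::set<int64_t> kept; // kept group keys; *kept.rbegin() is the current worst

    // Returns true if a row with this group key should still be aggregated.
    bool admit(int64_t key) {
        if (kept.count(key) > 0) {
            return true; // group already kept: always aggregate into it
        }
        if (kept.size() < limit) {
            kept.insert(key); // below the limit: every new group is admitted
            return true;
        }
        if (key < *kept.rbegin()) { // strictly beats the worst kept key
            kept.erase(std::prev(kept.end())); // evict the worst group
            kept.insert(key);
            return true;
        }
        return false; // cannot enter the top-N: drop the row before aggregation
    }
};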