Skip to content

Commit 2ab394e

Browse files
author
amory
authored
branch-3.1: [fix](collect_list) fix collect_list in multi BE will ser/de failed apache#48314 (apache#56752)
picked from apache#48314
1 parent e8c0f85 commit 2ab394e

File tree

4 files changed

+701
-9
lines changed

4 files changed

+701
-9
lines changed

be/src/vec/aggregate_functions/aggregate_function_collect.h

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -312,13 +312,15 @@ template <typename HasLimit>
312312
struct AggregateFunctionCollectListData<void, HasLimit> {
313313
using ElementType = StringRef;
314314
using Self = AggregateFunctionCollectListData<void, HasLimit>;
315+
DataTypeSerDeSPtr serde; // for complex serialize && deserialize from multi BE
315316
MutableColumnPtr column_data;
316317
Int64 max_size = -1;
317318

318319
// Default constructor: performs no setup, so column_data and serde keep their
// default-constructed values and max_size keeps its in-class initializer.
AggregateFunctionCollectListData() {}
319320
// Builds the aggregate state from the first argument type only:
// creates the backing column for that type and caches the type's serde,
// which write()/read() use to (de)serialize each collected cell.
AggregateFunctionCollectListData(const DataTypes& argument_types) {
320321
DataTypePtr column_type = argument_types[0];
321322
column_data = column_type->create_column();
323+
serde = column_type->get_serde();
322324
}
323325

324326
// Returns the current size of column_data.
size_t size() const { return column_data->size(); }
@@ -345,21 +347,41 @@ struct AggregateFunctionCollectListData<void, HasLimit> {
345347
// Serializes this state into buf. Layout written, in order:
//   1. element count (write_binary)
//   2. one length-prefixed string per element (write_string_binary)
//   3. max_size (write_var_int)
// Throws doris::Exception(INTERNAL_ERROR) if serde fails to serialize a cell.
// NOTE(review): in this diff rendering the statement under the `349-` marker
// appears to be the pre-change line that the `+` lines replace.
void write(BufferWritable& buf) const {
346348
const size_t size = column_data->size();
347349
write_binary(size, buf);
350+
351+
DataTypeSerDe::FormatOptions opt;
352+
// Scratch string column plus a writer over it, reused for every cell.
auto tmp_str = ColumnString::create();
353+
VectorBufferWriter tmp_buf(*tmp_str.get());
354+
348355
for (size_t i = 0; i < size; i++) {
349-
write_string_binary(column_data->get_data_at(i), buf);
356+
// Reset the scratch column so the cell written below lands at index 0.
tmp_str->clear();
357+
if (Status st = serde->serialize_one_cell_to_json(*column_data, i, tmp_buf, opt); !st) {
358+
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
359+
"Failed to serialize data for " + column_data->get_name() +
360+
" error: " + st.to_string());
361+
}
362+
// presumably finalizes the pending bytes as one row of tmp_str — confirm
// against VectorBufferWriter::commit.
tmp_buf.commit();
363+
write_string_binary(tmp_str->get_data_at(0), buf);
350364
}
365+
351366
write_var_int(max_size, buf);
352367
}
353368

354369
// Restores this state from buf, mirroring the order used by write():
// element count, then one string per element, then max_size.
// Existing contents of column_data are cleared before loading.
// Throws doris::Exception(INTERNAL_ERROR) if serde fails to deserialize a cell.
// NOTE(review): in this diff rendering the statement under the `362-` marker
// appears to be the pre-change line that the `+` lines replace.
void read(BufferReadable& buf) {
355370
size_t size = 0;
356371
read_binary(size, buf);
372+
column_data->clear();
357373
column_data->reserve(size);
358374

359375
StringRef s;
376+
DataTypeSerDe::FormatOptions opt;
360377
for (size_t i = 0; i < size; i++) {
361378
read_string_binary(s, buf);
362-
column_data->insert_data(s.data, s.size);
379+
// Wrap the bytes just read and let serde append the decoded cell to column_data.
Slice slice(s.data, s.size);
380+
if (Status st = serde->deserialize_one_cell_from_json(*column_data, slice, opt); !st) {
381+
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
382+
"Failed to deserialize data for " + column_data->get_name() +
383+
" error: " + st.to_string());
384+
}
363385
}
364386
read_var_int(max_size, buf);
365387
}
@@ -613,7 +635,6 @@ struct AggregateFunctionArrayAggData<void> {
613635
size_t size = 0;
614636
read_binary(size, buf);
615637
column_data->reserve(size);
616-
617638
StringRef s;
618639
for (size_t i = 0; i < size; i++) {
619640
read_string_binary(s, buf);

be/test/vec/aggregate_functions/agg_collect_test.cpp

Lines changed: 39 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -63,12 +63,29 @@ class VAggCollectTest : public testing::Test {
6363

6464
template <typename DataType>
6565
void agg_collect_add_elements(AggregateFunctionPtr agg_function, AggregateDataPtr place,
66-
size_t input_nums) {
66+
size_t input_nums, bool support_complex = false) {
6767
using FieldType = typename DataType::FieldType;
68-
auto type = std::make_shared<DataType>();
69-
auto input_col = type->create_column();
68+
MutableColumnPtr input_col;
69+
if (support_complex) {
70+
auto type =
71+
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataType>()));
72+
input_col = type->create_column();
73+
} else {
74+
auto type = std::make_shared<DataType>();
75+
input_col = type->create_column();
76+
}
7077
for (size_t i = 0; i < input_nums; ++i) {
7178
for (size_t j = 0; j < _repeated_times; ++j) {
79+
if (support_complex) {
80+
if constexpr (std::is_same_v<DataType, DataTypeString>) {
81+
Array vec1 = {Field(String("item0" + std::to_string(i))),
82+
Field(String("item1" + std::to_string(i)))};
83+
input_col->insert(vec1);
84+
} else {
85+
input_col->insert_default();
86+
}
87+
continue;
88+
}
7289
if constexpr (std::is_same_v<DataType, DataTypeString>) {
7390
auto item = std::string("item") + std::to_string(i);
7491
input_col->insert_data(item.c_str(), item.size());
@@ -87,8 +104,13 @@ class VAggCollectTest : public testing::Test {
87104
}
88105

89106
template <typename DataType>
90-
void test_agg_collect(const std::string& fn_name, size_t input_nums = 0) {
107+
void test_agg_collect(const std::string& fn_name, size_t input_nums = 0,
108+
bool support_complex = false) {
91109
DataTypes data_types = {(DataTypePtr)std::make_shared<DataType>()};
110+
if (support_complex) {
111+
data_types = {
112+
(DataTypePtr)std::make_shared<DataTypeArray>(make_nullable(data_types[0]))};
113+
}
92114
LOG(INFO) << "test_agg_collect for " << fn_name << "(" << data_types[0]->get_name() << ")";
93115
AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance();
94116
auto agg_function = factory.get(fn_name, data_types, false, -1);
@@ -98,7 +120,7 @@ class VAggCollectTest : public testing::Test {
98120
AggregateDataPtr place = memory.get();
99121
agg_function->create(place);
100122

101-
agg_collect_add_elements<DataType>(agg_function, place, input_nums);
123+
agg_collect_add_elements<DataType>(agg_function, place, input_nums, support_complex);
102124

103125
ColumnString buf;
104126
VectorBufferWriter buf_writer(buf);
@@ -111,7 +133,7 @@ class VAggCollectTest : public testing::Test {
111133
AggregateDataPtr place2 = memory2.get();
112134
agg_function->create(place2);
113135

114-
agg_collect_add_elements<DataType>(agg_function, place2, input_nums);
136+
agg_collect_add_elements<DataType>(agg_function, place2, input_nums, support_complex);
115137

116138
agg_function->merge(place, place2, &_agg_arena_pool);
117139
auto column_result = ColumnArray::create(data_types[0]->create_column());
@@ -173,4 +195,15 @@ TEST_F(VAggCollectTest, test_with_data) {
173195
test_agg_collect<DataTypeString>("collect_set", 5);
174196
}
175197

198+
// Exercises collect_list and array_agg with support_complex = true (third
// argument), which makes test_agg_collect use Array(Nullable(T)) as the
// argument type, for integer, DateTime and String element types T.
TEST_F(VAggCollectTest, test_complex_data_type) {
199+
test_agg_collect<DataTypeInt8>("collect_list", 7, true);
200+
test_agg_collect<DataTypeInt128>("array_agg", 9, true);
201+
202+
test_agg_collect<DataTypeDateTime>("collect_list", 5, true);
203+
test_agg_collect<DataTypeDateTime>("array_agg", 6, true);
204+
205+
test_agg_collect<DataTypeString>("collect_list", 10, true);
206+
test_agg_collect<DataTypeString>("array_agg", 5, true);
207+
}
208+
176209
} // namespace doris::vectorized

0 commit comments

Comments
 (0)