Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions be/src/vec/functions/function_quantile_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "common/compiler_util.h" // IWYU pragma: keep
#include "common/status.h"
#include "util/quantile_state.h"
#include "util/url_coding.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_complex.h"
Expand All @@ -49,9 +50,11 @@
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_quantilestate.h" // IWYU pragma: keep
#include "vec/data_types/data_type_string.h"
#include "vec/functions/function.h"
#include "vec/functions/function_const.h"
#include "vec/functions/function_helpers.h"
#include "vec/functions/function_totype.h"
#include "vec/functions/simple_function_factory.h"
#include "vec/utils/util.hpp"

Expand Down Expand Up @@ -218,10 +221,134 @@ class FunctionQuantileStatePercent : public IFunction {
}
};

class FunctionQuantileStateFromBase64 : public IFunction {
public:
static constexpr auto name = "quantile_state_from_base64";
String get_name() const override { return name; }

static FunctionPtr create() { return std::make_shared<FunctionQuantileStateFromBase64>(); }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeQuantileState>());
}

size_t get_number_of_arguments() const override { return 1; }

bool use_default_implementation_for_nulls() const override { return true; }

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
auto res_null_map = ColumnUInt8::create(input_rows_count, 0);
auto res_data_column = ColumnQuantileState::create();
auto& null_map = res_null_map->get_data();
auto& res = res_data_column->get_data();

auto& argument_column = block.get_by_position(arguments[0]).column;
const auto& str_column = static_cast<const ColumnString&>(*argument_column);
const ColumnString::Chars& data = str_column.get_chars();
const ColumnString::Offsets& offsets = str_column.get_offsets();

res.reserve(input_rows_count);

std::string decode_buff;
int last_decode_buff_len = 0;
int curr_decode_buff_len = 0;
for (size_t i = 0; i < input_rows_count; ++i) {
const char* src_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
int64_t src_size = offsets[i] - offsets[i - 1];

if (src_size == 0 || 0 != src_size % 4) {
res.emplace_back();
null_map[i] = 1;
continue;
}

curr_decode_buff_len = src_size + 3;
if (curr_decode_buff_len > last_decode_buff_len) {
decode_buff.resize(curr_decode_buff_len);
last_decode_buff_len = curr_decode_buff_len;
}
auto outlen = base64_decode(src_str, src_size, decode_buff.data());
if (outlen < 0) {
res.emplace_back();
null_map[i] = 1;
} else {
doris::Slice decoded_slice(decode_buff.data(), outlen);
doris::QuantileState quantile_state;
if (!quantile_state.deserialize(decoded_slice)) {
return Status::RuntimeError(fmt::format(
"quantile_state_from_base64 decode failed: base64: {}", src_str));
} else {
res.emplace_back(std::move(quantile_state));
}
}
}

block.get_by_position(result).column =
ColumnNullable::create(std::move(res_data_column), std::move(res_null_map));
return Status::OK();
}
};

struct NameQuantileStateToBase64 {
static constexpr auto name = "quantile_state_to_base64";
};

struct QuantileStateToBase64 {
using ReturnType = DataTypeString;
static constexpr auto TYPE_INDEX = TypeIndex::QuantileState;
using Type = DataTypeQuantileState::FieldType;
using ReturnColumnType = ColumnString;
using Chars = ColumnString::Chars;
using Offsets = ColumnString::Offsets;

static Status vector(const std::vector<QuantileState>& data, Chars& chars, Offsets& offsets) {
size_t size = data.size();
offsets.resize(size);
size_t output_char_size = 0;
for (size_t i = 0; i < size; ++i) {
auto& quantile_state_val = const_cast<QuantileState&>(data[i]);
auto ser_size = quantile_state_val.get_serialized_size();
output_char_size += (int)(4.0 * ceil((double)ser_size / 3.0));
}
ColumnString::check_chars_length(output_char_size, size);
chars.resize(output_char_size);
auto* chars_data = chars.data();

size_t cur_ser_size = 0;
size_t last_ser_size = 0;
std::string ser_buff;
size_t encoded_offset = 0;
for (size_t i = 0; i < size; ++i) {
auto& quantile_state_val = const_cast<QuantileState&>(data[i]);

cur_ser_size = quantile_state_val.get_serialized_size();
if (cur_ser_size > last_ser_size) {
last_ser_size = cur_ser_size;
ser_buff.resize(cur_ser_size);
}
size_t real_size =
quantile_state_val.serialize(reinterpret_cast<uint8_t*>(ser_buff.data()));
auto outlen = base64_encode((const unsigned char*)ser_buff.data(), real_size,
chars_data + encoded_offset);
DCHECK(outlen > 0);

encoded_offset += outlen;
offsets[i] = encoded_offset;
}
return Status::OK();
}
};

using FunctionQuantileStateToBase64 =
FunctionUnaryToType<QuantileStateToBase64, NameQuantileStateToBase64>;

void register_function_quantile_state(SimpleFunctionFactory& factory) {
factory.register_function<FunctionConst<QuantileStateEmpty, false>>();
factory.register_function<FunctionQuantileStatePercent>();
factory.register_function<FunctionToQuantileState>();
factory.register_function<FunctionQuantileStateFromBase64>();
factory.register_function<FunctionQuantileStateToBase64>();
}

} // namespace doris::vectorized
215 changes: 215 additions & 0 deletions be/test/vec/function/function_quantile_state_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>

#include <string>

#include "function_test_util.h"
#include "util/quantile_state.h"
#include "util/url_coding.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_quantilestate.h"
#include "vec/data_types/data_type_string.h"

namespace doris::vectorized {

TEST(function_quantile_state_test, function_quantile_state_to_base64) {
std::string func_name = "quantile_state_to_base64";
InputTypeSet input_types = {TypeIndex::QuantileState};

QuantileState empty_quantile_state;

QuantileState single_quantile_state;
single_quantile_state.add_value(1.0);

QuantileState multi_quantile_state;
multi_quantile_state.add_value(1.0);
multi_quantile_state.add_value(2.0);
multi_quantile_state.add_value(3.0);
multi_quantile_state.add_value(4.0);
multi_quantile_state.add_value(5.0);

QuantileState explicit_quantile_state;
for (int i = 0; i < 100; i++) {
explicit_quantile_state.add_value(static_cast<double>(i));
}

QuantileState tdigest_quantile_state;
for (int i = 0; i < 3000; i++) {
tdigest_quantile_state.add_value(static_cast<double>(i));
}

uint8_t buf[65536];
unsigned char encoded_buf[131072];

std::string empty_base64;
{
size_t len = empty_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
empty_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

std::string single_base64;
{
size_t len = single_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
single_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

std::string multi_base64;
{
size_t len = multi_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
multi_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

std::string explicit_base64;
{
size_t len = explicit_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
explicit_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

std::string tdigest_base64;
{
size_t len = tdigest_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
tdigest_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

{
DataSet data_set = {{{&empty_quantile_state}, empty_base64},
{{&single_quantile_state}, single_base64},
{{&multi_quantile_state}, multi_base64},
{{&explicit_quantile_state}, explicit_base64},
{{&tdigest_quantile_state}, tdigest_base64}};

static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
}
}

TEST(function_quantile_state_test, function_quantile_state_from_base64) {
std::string func_name = "quantile_state_from_base64";
InputTypeSet input_types = {TypeIndex::String};

// Create quantile states for comparison
QuantileState empty_quantile_state;

QuantileState single_quantile_state;
single_quantile_state.add_value(1.0);

QuantileState multi_quantile_state;
multi_quantile_state.add_value(1.0);
multi_quantile_state.add_value(2.0);
multi_quantile_state.add_value(3.0);
multi_quantile_state.add_value(4.0);
multi_quantile_state.add_value(5.0);

uint8_t buf[65536];
unsigned char encoded_buf[131072];
std::string empty_base64;
std::string single_base64;
std::string multi_base64;

{
size_t len = empty_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
empty_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

{
size_t len = single_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
single_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

{
size_t len = multi_quantile_state.serialize(buf);
size_t encoded_len = base64_encode(buf, len, encoded_buf);
multi_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
}

{
char decoded_buf[65536];
int decoded_len = base64_decode(empty_base64.c_str(), empty_base64.length(), decoded_buf);
EXPECT_GT(decoded_len, 0);

QuantileState decoded_empty;
doris::Slice slice(decoded_buf, decoded_len);
EXPECT_TRUE(decoded_empty.deserialize(slice));

EXPECT_TRUE(std::isnan(empty_quantile_state.get_value_by_percentile(0.5)));
EXPECT_TRUE(std::isnan(decoded_empty.get_value_by_percentile(0.5)));
}

{
char decoded_buf[65536];
int decoded_len = base64_decode(single_base64.c_str(), single_base64.length(), decoded_buf);
EXPECT_GT(decoded_len, 0);

QuantileState decoded_single;
doris::Slice slice(decoded_buf, decoded_len);
EXPECT_TRUE(decoded_single.deserialize(slice));

EXPECT_NEAR(single_quantile_state.get_value_by_percentile(0.5),
decoded_single.get_value_by_percentile(0.5), 0.01);
}

{
char decoded_buf[65536];
int decoded_len = base64_decode(multi_base64.c_str(), multi_base64.length(), decoded_buf);
EXPECT_GT(decoded_len, 0);

QuantileState decoded_multi;
doris::Slice slice(decoded_buf, decoded_len);
EXPECT_TRUE(decoded_multi.deserialize(slice));

EXPECT_NEAR(multi_quantile_state.get_value_by_percentile(0.5),
decoded_multi.get_value_by_percentile(0.5), 0.01);
EXPECT_NEAR(multi_quantile_state.get_value_by_percentile(0.9),
decoded_multi.get_value_by_percentile(0.9), 0.01);
}
}

TEST(function_quantile_state_test, function_quantile_state_roundtrip) {
QuantileState original;
for (int i = 0; i < 50; i++) {
original.add_value(static_cast<double>(i * 2));
}

uint8_t ser_buf[65536];
size_t ser_len = original.serialize(ser_buf);

unsigned char encoded_buf[131072];
size_t encoded_len = base64_encode(ser_buf, ser_len, encoded_buf);
std::string base64_str(reinterpret_cast<char*>(encoded_buf), encoded_len);

char decoded_buf[65536];
int decoded_len = base64_decode(base64_str.c_str(), base64_str.length(), decoded_buf);
EXPECT_GT(decoded_len, 0);

QuantileState recovered;
doris::Slice slice(decoded_buf, decoded_len);
EXPECT_TRUE(recovered.deserialize(slice));

EXPECT_NEAR(original.get_value_by_percentile(0.5), recovered.get_value_by_percentile(0.5),
0.01);
EXPECT_NEAR(original.get_value_by_percentile(0.9), recovered.get_value_by_percentile(0.9),
0.01);
}

} // namespace doris::vectorized
Loading