Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/action-pr-title
Submodule action-pr-title updated 548 files
2 changes: 1 addition & 1 deletion .github/actions/get-workflow-origin
92 changes: 91 additions & 1 deletion be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,95 @@ struct StringAppendTrailingCharIfAbsent {
}
};

struct HammingDistanceImpl {
static constexpr auto name = "hamming_distance";
using Chars = ColumnString::Chars;
using Offsets = ColumnString::Offsets;
using ReturnType = DataTypeInt64;
using ColumnType = ColumnInt64;

// Calculate Hamming distance between two strings of equal length
static Int64 calculate_hamming_distance(const StringRef& str1, const StringRef& str2) {
DCHECK_EQ(str1.size, str2.size);
Int64 distance = 0;
for (size_t i = 0; i < str1.size; ++i) {
if (str1.data[i] != str2.data[i]) {
++distance;
}
}
return distance;
}

// vector_vector: both arguments are columns
static void vector_vector(FunctionContext* context, const Chars& ldata, const Offsets& loffsets,
const Chars& rdata, const Offsets& roffsets,
PaddedPODArray<Int64>& res, NullMap& null_map_data) {
DCHECK_EQ(loffsets.size(), roffsets.size());
size_t input_rows_count = loffsets.size();
res.resize(input_rows_count);

for (size_t i = 0; i < input_rows_count; ++i) {
StringRef lstr = StringRef(reinterpret_cast<const char*>(&ldata[loffsets[i - 1]]),
loffsets[i] - loffsets[i - 1]);
StringRef rstr = StringRef(reinterpret_cast<const char*>(&rdata[roffsets[i - 1]]),
roffsets[i] - roffsets[i - 1]);

// Throw an error if strings have different lengths (enforce contract).
if (lstr.size != rstr.size) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT,
"hamming_distance: input strings must have equal length, got {} and {}",
lstr.size, rstr.size);
}
res[i] = calculate_hamming_distance(lstr, rstr);
}
}

// vector_scalar: first argument is column, second is constant
static void vector_scalar(FunctionContext* context, const Chars& ldata, const Offsets& loffsets,
const StringRef& rstr, PaddedPODArray<Int64>& res,
NullMap& null_map_data) {
size_t input_rows_count = loffsets.size();
res.resize(input_rows_count);

for (size_t i = 0; i < input_rows_count; ++i) {
StringRef lstr = StringRef(reinterpret_cast<const char*>(&ldata[loffsets[i - 1]]),
loffsets[i] - loffsets[i - 1]);

// Throw an error if strings have different lengths (enforce contract).
if (lstr.size != rstr.size) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT,
"hamming_distance: input strings must have equal length, got {} and {}",
lstr.size, rstr.size);
}
res[i] = calculate_hamming_distance(lstr, rstr);
}
}

// scalar_vector: first argument is constant, second is column
static void scalar_vector(FunctionContext* context, const StringRef& lstr, const Chars& rdata,
const Offsets& roffsets, PaddedPODArray<Int64>& res,
NullMap& null_map_data) {
size_t input_rows_count = roffsets.size();
res.resize(input_rows_count);

for (size_t i = 0; i < input_rows_count; ++i) {
StringRef rstr = StringRef(reinterpret_cast<const char*>(&rdata[roffsets[i - 1]]),
roffsets[i] - roffsets[i - 1]);

// Throw an error if strings have different lengths (enforce contract).
if (lstr.size != rstr.size) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT,
"hamming_distance: input strings must have equal length, got {} and {}",
lstr.size, rstr.size);
}
res[i] = calculate_hamming_distance(lstr, rstr);
}
}
};

struct StringLPad {
static constexpr auto name = "lpad";
static constexpr auto is_lpad = true;
Expand Down Expand Up @@ -1342,7 +1431,7 @@ using FunctionFromBase64 = FunctionStringOperateToNullType<FromBase64Impl>;

// Concrete function types: each alias instantiates a generic function template
// with the Impl struct that supplies the per-row logic.
using FunctionStringAppendTrailingCharIfAbsent =
FunctionBinaryStringOperateToNullType<StringAppendTrailingCharIfAbsent>;

// hamming_distance uses the same binary string template, parameterised with HammingDistanceImpl.
using FunctionHammingDistance = FunctionBinaryStringOperateToNullType<HammingDistanceImpl>;
// Both pad aliases instantiate FunctionStringPad, each with its own Impl struct.
using FunctionStringLPad = FunctionStringPad<StringLPad>;
using FunctionStringRPad = FunctionStringPad<StringRPad>;

Expand Down Expand Up @@ -1440,6 +1529,7 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionSubReplace<SubReplaceFourImpl>>();
factory.register_function<FunctionOverlay>();
factory.register_function<FunctionStrcmp>();
factory.register_function<FunctionHammingDistance>();
factory.register_function<FunctionNgramSearch>();
factory.register_function<FunctionXPathString>();
factory.register_function<FunctionCrc32Internal>();
Expand Down
102 changes: 20 additions & 82 deletions be/test/vec/function/function_string_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3778,95 +3778,33 @@ TEST(function_string_test, function_sha1_test) {
}
}

TEST(function_string_test, function_unicode_normalize_nfc_basic) {
std::string func_name = "unicode_normalize";

InputTypeSet input_types = {
PrimitiveType::TYPE_VARCHAR,
Consted {PrimitiveType::TYPE_VARCHAR},
};

std::string cafe_decomposed = std::string("Cafe\xCC\x81");
std::string cafe_composed = std::string("Caf\xC3\xA9");
TEST(function_string_test, function_hamming_distance_test) {
std::string func_name = "hamming_distance";

{
DataSet data_set = {
{{cafe_decomposed, std::string("NFC")}, cafe_composed},
};
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
}

{
DataSet data_set = {
{{cafe_composed, std::string("NFC")}, cafe_composed},
};
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
}
}

TEST(function_string_test, function_unicode_normalize_modes_and_trim) {
std::string func_name = "unicode_normalize";

InputTypeSet input_types = {
PrimitiveType::TYPE_VARCHAR,
Consted {PrimitiveType::TYPE_VARCHAR},
};

std::string cafe_decomposed = std::string("Cafe\xCC\x81");
std::string cafe_composed = std::string("Caf\xC3\xA9");

{
DataSet data_set = {
{{cafe_composed, std::string(" nFd ")}, cafe_decomposed},
};
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
}
InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR, PrimitiveType::TYPE_VARCHAR};

{
DataSet data_set = {
{{std::string("ABC 123"), std::string(" nfkc_cf ")}, std::string("abc 123")},
// Same strings - distance 0
{{std::string("abc"), std::string("abc")}, std::int64_t(0)},
{{std::string(""), std::string("")}, std::int64_t(0)},
{{std::string("hello"), std::string("hello")}, std::int64_t(0)},

// Different strings - distance > 0
{{std::string("abc"), std::string("axc")}, std::int64_t(1)},
{{std::string("abc"), std::string("xyz")}, std::int64_t(3)},
{{std::string("hello"), std::string("hallo")}, std::int64_t(1)},
{{std::string("test"), std::string("text")}, std::int64_t(1)},
{{std::string("abcd"), std::string("abed")}, std::int64_t(1)},

// NULL inputs
{{Null(), std::string("abc")}, Null()},
{{std::string("abc"), Null()}, Null()},
{{Null(), Null()}, Null()},
};
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
}

{
DataSet data_set = {
{{std::string("plain-ascii"), std::string("NFKD")}, std::string("plain-ascii")},
};
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
check_function_all_arg_comb<DataTypeInt64, true>(func_name, input_types, data_set);
}
}

TEST(function_string_test, function_unicode_normalize_mode_not_const) {
    // Both arguments are declared as ordinary VARCHAR columns; in particular
    // the mode argument is NOT wrapped in Consted.
    InputTypeSet arg_types = {PrimitiveType::TYPE_VARCHAR, PrimitiveType::TYPE_VARCHAR};

    std::string function_under_test = "unicode_normalize";

    // Single row: input "abc", mode "NFC".
    DataSet rows = {{{std::string("abc"), std::string("NFC")}, std::string("abc")}};

    // The check is expected to come back with a status other than OK.
    Status outcome = check_function<DataTypeString, true>(function_under_test, arg_types, rows);
    EXPECT_NE(Status::OK(), outcome);
}

TEST(function_string_test, function_unicode_normalize_invalid_mode) {
    // First argument is a VARCHAR column, second (the mode) is a Consted VARCHAR.
    InputTypeSet arg_types = {PrimitiveType::TYPE_VARCHAR, Consted {PrimitiveType::TYPE_VARCHAR}};

    std::string function_under_test = "unicode_normalize";

    // Single row whose mode string is "INVALID_MODE".
    DataSet rows = {{{std::string("abc"), std::string("INVALID_MODE")}, std::string("abc")}};

    // The check is expected to come back with a status other than OK.
    Status outcome = check_function<DataTypeString, true>(function_under_test, arg_types, rows);
    EXPECT_NE(Status::OK(), outcome);
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Nereids expression node for the scalar function {@code hamming_distance(arg0, arg1)}.
 *
 * <p>Two signatures are declared, (VARCHAR, VARCHAR) and (STRING, STRING); both return BIGINT.
 * The class is marked {@link AlwaysNullable}.
 */
public class HammingDistance extends ScalarFunction
        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

    /** Supported argument/return type combinations, in declaration order. */
    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature.ret(BigIntType.INSTANCE)
                    .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
            FunctionSignature.ret(BigIntType.INSTANCE)
                    .args(StringType.INSTANCE, StringType.INSTANCE));

    /**
     * Builds the function from its two operand expressions.
     */
    public HammingDistance(Expression arg0, Expression arg1) {
        super("hamming_distance", arg0, arg1);
    }

    /** Internal constructor used by {@link #withChildren(List)} to rebuild from existing params. */
    private HammingDistance(ScalarFunctionParams functionParams) {
        super(functionParams);
    }

    /**
     * Dispatches to {@code visitHammingDistance} on the given visitor.
     */
    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitHammingDistance(this, context);
    }

    /**
     * Returns the shared {@link #SIGNATURES} list.
     */
    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }

    /**
     * Creates a copy of this function over the given children; exactly two are required.
     */
    @Override
    public HammingDistance withChildren(List<Expression> children) {
        Preconditions.checkArgument(children.size() == 2);
        return new HammingDistance(getFunctionParams(children));
    }
}

Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.TryCast;
import org.apache.doris.nereids.trees.expressions.UnaryArithmetic;
import org.apache.doris.nereids.trees.expressions.UnaryOperator;
Expand All @@ -90,6 +91,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HammingDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
Expand Down Expand Up @@ -147,6 +149,18 @@ public R visitScalarFunction(ScalarFunction scalarFunction, C context) {
return visitBoundFunction(scalarFunction, context);
}

/** Visits a ToSeconds expression; forwards it to visitScalarFunction. */
public R visitToSeconds(org.apache.doris.nereids.trees.expressions.functions.scalar.ToSeconds toSeconds, C context) {
return visitScalarFunction(toSeconds, context);
}

/** Visits a UnicodeNormalize expression; forwards it to visitScalarFunction. */
public R visitUnicodeNormalize(org.apache.doris.nereids.trees.expressions.functions.scalar.UnicodeNormalize unicodeNormalize, C context) {
return visitScalarFunction(unicodeNormalize, context);
}

/**
 * Visits a HammingDistance expression; forwards it to visitScalarFunction.
 * Uses the simple type name because this file already imports HammingDistance.
 */
public R visitHammingDistance(HammingDistance hammingDistance, C context) {
    return visitScalarFunction(hammingDistance, context);
}

/** Visits a SearchExpression; forwards it to the generic visit method. */
public R visitSearchExpression(SearchExpression searchExpression, C context) {
return visit(searchExpression, context);
}
Expand Down Expand Up @@ -455,6 +469,10 @@ public R visitSubqueryExpr(SubqueryExpr subqueryExpr, C context) {
return visit(subqueryExpr, context);
}

/** Visits a TimestampArithmetic expression; forwards it to the generic visit method. */
public R visitTimestampArithmetic(TimestampArithmetic arithmetic, C context) {
return visit(arithmetic, context);
}

/** Visits a ScalarSubquery; forwards it to visitSubqueryExpr. */
public R visitScalarSubquery(ScalarSubquery scalar, C context) {
return visitSubqueryExpr(scalar, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetFormat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetVariantType;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Greatest;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HammingDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToSeconds;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Hex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllEmpty;
Expand Down Expand Up @@ -2505,6 +2507,10 @@ default R visitStrcmp(Strcmp strcmp, C context) {
return visitScalarFunction(strcmp, context);
}

/** Default handling for HammingDistance: forwards to visitScalarFunction. */
default R visitHammingDistance(HammingDistance hammingDistance, C context) {
return visitScalarFunction(hammingDistance, context);
}

/** Default handling for StripNullValue: forwards to visitScalarFunction. */
default R visitStripNullValue(StripNullValue stripNullValue, C context) {
return visitScalarFunction(stripNullValue, context);
}
Expand Down
Loading