|
| 1 | +#include <DataTypes/DataTypeString.h> |
| 2 | +#include <DataTypes/DataTypeArray.h> |
| 3 | +#include <Columns/ColumnString.h> |
| 4 | +#include <Columns/ColumnFixedString.h> |
| 5 | +#include <Columns/ColumnArray.h> |
| 6 | +#include <Interpreters/Context_fwd.h> |
| 7 | +#include <Interpreters/ITokenExtractor.h> |
| 8 | +#include <Functions/IFunction.h> |
| 9 | +#include <Functions/FunctionHelpers.h> |
| 10 | +#include <Functions/FunctionFactory.h> |
| 11 | + |
| 12 | + |
| 13 | +namespace DB |
| 14 | +{ |
| 15 | + |
| 16 | +namespace ErrorCodes |
| 17 | +{ |
| 18 | + extern const int BAD_ARGUMENTS; |
| 19 | +} |
| 20 | + |
| 21 | +class FunctionNgrams : public IFunction |
| 22 | +{ |
| 23 | +public: |
| 24 | + |
| 25 | + static constexpr auto name = "ngrams"; |
| 26 | + |
| 27 | + static FunctionPtr create(ContextPtr) |
| 28 | + { |
| 29 | + return std::make_shared<FunctionNgrams>(); |
| 30 | + } |
| 31 | + |
| 32 | + String getName() const override { return name; } |
| 33 | + |
| 34 | + size_t getNumberOfArguments() const override { return 2; } |
| 35 | + bool isVariadic() const override { return false; } |
| 36 | + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return ColumnNumbers{1}; } |
| 37 | + |
| 38 | + bool useDefaultImplementationForNulls() const override { return true; } |
| 39 | + bool useDefaultImplementationForConstants() const override { return true; } |
| 40 | + bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } |
| 41 | + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } |
| 42 | + |
| 43 | + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override |
| 44 | + { |
| 45 | + auto ngram_input_argument_type = WhichDataType(arguments[0].type); |
| 46 | + if (!ngram_input_argument_type.isStringOrFixedString()) |
| 47 | + throw Exception(ErrorCodes::BAD_ARGUMENTS, |
| 48 | + "Function {} first argument type should be String or FixedString. Actual {}", |
| 49 | + getName(), |
| 50 | + arguments[0].type->getName()); |
| 51 | + |
| 52 | + const auto & column_with_type = arguments[1]; |
| 53 | + const auto & ngram_argument_column = arguments[1].column; |
| 54 | + auto ngram_argument_type = WhichDataType(column_with_type.type); |
| 55 | + |
| 56 | + if (!ngram_argument_type.isNativeUInt() || !ngram_argument_column || !isColumnConst(*ngram_argument_column)) |
| 57 | + throw Exception(ErrorCodes::BAD_ARGUMENTS, |
| 58 | + "Function {} second argument type should be constant UInt. Actual {}", |
| 59 | + getName(), |
| 60 | + arguments[1].type->getName()); |
| 61 | + |
| 62 | + return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>()); |
| 63 | + } |
| 64 | + |
| 65 | + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override |
| 66 | + { |
| 67 | + auto column_offsets = ColumnArray::ColumnOffsets::create(); |
| 68 | + |
| 69 | + Field ngram_argument_value; |
| 70 | + arguments[1].column->get(0, ngram_argument_value); |
| 71 | + auto ngram_value = ngram_argument_value.safeGet<UInt64>(); |
| 72 | + |
| 73 | + NgramTokenExtractor extractor(ngram_value); |
| 74 | + |
| 75 | + auto result_column_string = ColumnString::create(); |
| 76 | + |
| 77 | + auto input_column = arguments[0].column; |
| 78 | + |
| 79 | + if (const auto * column_string = checkAndGetColumn<ColumnString>(input_column.get())) |
| 80 | + executeImpl(extractor, *column_string, *result_column_string, *column_offsets, input_rows_count); |
| 81 | + else if (const auto * column_fixed_string = checkAndGetColumn<ColumnFixedString>(input_column.get())) |
| 82 | + executeImpl(extractor, *column_fixed_string, *result_column_string, *column_offsets, input_rows_count); |
| 83 | + |
| 84 | + return ColumnArray::create(std::move(result_column_string), std::move(column_offsets)); |
| 85 | + } |
| 86 | + |
| 87 | +private: |
| 88 | + |
| 89 | + template <typename ExtractorType, typename StringColumnType, typename ResultStringColumnType> |
| 90 | + void executeImpl( |
| 91 | + const ExtractorType & extractor, |
| 92 | + StringColumnType & input_data_column, |
| 93 | + ResultStringColumnType & result_data_column, |
| 94 | + ColumnArray::ColumnOffsets & offsets_column, |
| 95 | + size_t input_rows_count) const |
| 96 | + { |
| 97 | + size_t current_tokens_size = 0; |
| 98 | + auto & offsets_data = offsets_column.getData(); |
| 99 | + |
| 100 | + offsets_data.resize(input_rows_count); |
| 101 | + |
| 102 | + for (size_t i = 0; i < input_rows_count; ++i) |
| 103 | + { |
| 104 | + auto data = input_data_column.getDataAt(i); |
| 105 | + |
| 106 | + size_t cur = 0; |
| 107 | + size_t token_start = 0; |
| 108 | + size_t token_length = 0; |
| 109 | + |
| 110 | + while (cur < data.size && extractor.nextInString(data.data, data.size, &cur, &token_start, &token_length)) |
| 111 | + { |
| 112 | + result_data_column.insertData(data.data + token_start, token_length); |
| 113 | + ++current_tokens_size; |
| 114 | + } |
| 115 | + |
| 116 | + offsets_data[i] = current_tokens_size; |
| 117 | + } |
| 118 | + } |
| 119 | +}; |
| 120 | + |
| 121 | +REGISTER_FUNCTION(Ngrams) |
| 122 | +{ |
| 123 | + factory.registerFunction<FunctionNgrams>(FunctionDocumentation{ |
| 124 | + .description = "Splits a UTF-8 string into n-grams symbols.", |
| 125 | + .category = FunctionDocumentation::Category::StringSplitting}); |
| 126 | +} |
| 127 | + |
| 128 | +} |
0 commit comments