Skip to content

Commit b32c9d6

Browse files
Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507) (#9514)
close #9473 Support other integer types for SubstringUTF8 & RightUTF8 functions Signed-off-by: JaySon-Huang <tshent@qq.com> Co-authored-by: JaySon <tshent@qq.com> Co-authored-by: JaySon-Huang <tshent@qq.com>
1 parent a95bef8 commit b32c9d6

File tree

6 files changed

+403
-157
lines changed

6 files changed

+403
-157
lines changed

dbms/src/Functions/FunctionsString.cpp

Lines changed: 120 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,26 +1682,41 @@ class FunctionSubstringUTF8 : public IFunction
16821682
bool is_start_type_valid
16831683
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
16841684
using StartType = std::decay_t<decltype(start_type)>;
1685-
// Int64 / UInt64
16861685
using StartFieldType = typename StartType::FieldType;
1686+
const ColumnVector<StartFieldType> * column_vector_start
1687+
= getInnerColumnVector<StartFieldType>(column_start);
1688+
if unlikely (!column_vector_start)
1689+
throw Exception(
1690+
fmt::format(
1691+
"Illegal type {} of argument 2 of function {}",
1692+
block.getByPosition(arguments[1]).type->getName(),
1693+
getName()),
1694+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
16871695

16881696
// vector const const
16891697
if (!column_string->isColumnConst() && column_start->isColumnConst()
16901698
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
16911699
{
1692-
auto [is_positive, start_abs]
1693-
= getValueFromStartField<StartFieldType>((*block.getByPosition(arguments[1]).column)[0]);
1700+
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
16941701
UInt64 length = 0;
16951702
if (!implicit_length)
16961703
{
16971704
bool is_length_type_valid = getNumberType(
16981705
block.getByPosition(arguments[2]).type,
16991706
[&](const auto & length_type, bool) {
17001707
using LengthType = std::decay_t<decltype(length_type)>;
1701-
// Int64 / UInt64
17021708
using LengthFieldType = typename LengthType::FieldType;
1703-
length = getValueFromLengthField<LengthFieldType>(
1704-
(*block.getByPosition(arguments[2]).column)[0]);
1709+
const ColumnVector<LengthFieldType> * column_vector_length
1710+
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
1711+
if unlikely (!column_vector_length)
1712+
throw Exception(
1713+
fmt::format(
1714+
"Illegal type {} of argument 3 of function {}",
1715+
block.getByPosition(arguments[2]).type->getName(),
1716+
getName()),
1717+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
1718+
1719+
length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
17051720
return true;
17061721
});
17071722

@@ -1736,15 +1751,15 @@ class FunctionSubstringUTF8 : public IFunction
17361751
if (column_start->isColumnConst())
17371752
{
17381753
// func always return const value
1739-
auto start_const = getValueFromStartField<StartFieldType>((*column_start)[0]);
1754+
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
17401755
get_start_func = [start_const](size_t) {
17411756
return start_const;
17421757
};
17431758
}
17441759
else
17451760
{
1746-
get_start_func = [&column_start](size_t i) {
1747-
return getValueFromStartField<StartFieldType>((*column_start)[i]);
1761+
get_start_func = [column_vector_start](size_t i) {
1762+
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
17481763
};
17491764
}
17501765

@@ -1757,26 +1772,36 @@ class FunctionSubstringUTF8 : public IFunction
17571772
block.getByPosition(arguments[2]).type,
17581773
[&](const auto & length_type, bool) {
17591774
using LengthType = std::decay_t<decltype(length_type)>;
1760-
// Int64 / UInt64
17611775
using LengthFieldType = typename LengthType::FieldType;
1776+
const ColumnVector<LengthFieldType> * column_vector_length
1777+
= getInnerColumnVector<LengthFieldType>(column_length);
1778+
if unlikely (!column_vector_length)
1779+
throw Exception(
1780+
fmt::format(
1781+
"Illegal type {} of argument 3 of function {}",
1782+
block.getByPosition(arguments[2]).type->getName(),
1783+
getName()),
1784+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
1785+
17621786
if (column_length->isColumnConst())
17631787
{
17641788
// func always return const value
1765-
auto length_const = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
1789+
auto length_const
1790+
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
17661791
get_length_func = [length_const](size_t) {
17671792
return length_const;
17681793
};
17691794
}
17701795
else
17711796
{
1772-
get_length_func = [column_length](size_t i) {
1773-
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
1797+
get_length_func = [column_vector_length](size_t i) {
1798+
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
17741799
};
17751800
}
17761801
return true;
17771802
});
17781803

1779-
if (!is_length_type_valid)
1804+
if unlikely (!is_length_type_valid)
17801805
throw Exception(
17811806
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
17821807
}
@@ -1814,10 +1839,38 @@ class FunctionSubstringUTF8 : public IFunction
18141839
return true;
18151840
});
18161841

1817-
if (!is_start_type_valid)
1842+
if unlikely (!is_start_type_valid)
18181843
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
18191844
}
18201845

1846+
template <typename Integer>
1847+
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
1848+
{
1849+
if (column->isColumnConst())
1850+
return checkAndGetColumn<ColumnVector<Integer>>(
1851+
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
1852+
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
1853+
}
1854+
1855+
template <typename Integer>
1856+
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
1857+
{
1858+
Integer val = column.getElement(index);
1859+
if constexpr (
1860+
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
1861+
|| std::is_same_v<Integer, Int64>)
1862+
{
1863+
return val < 0 ? 0 : val;
1864+
}
1865+
else
1866+
{
1867+
static_assert(
1868+
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
1869+
|| std::is_same_v<Integer, UInt64>);
1870+
return val;
1871+
}
1872+
}
1873+
18211874
private:
18221875
using VectorConstConstFunc = std::function<void(
18231876
const ColumnString::Chars_t &,
@@ -1841,49 +1894,40 @@ class FunctionSubstringUTF8 : public IFunction
18411894
}
18421895
}
18431896

1844-
template <typename Integer>
1845-
static size_t getValueFromLengthField(const Field & length_field)
1846-
{
1847-
if constexpr (std::is_same_v<Integer, Int64>)
1848-
{
1849-
Int64 signed_length = length_field.get<Int64>();
1850-
return signed_length < 0 ? 0 : signed_length;
1851-
}
1852-
else
1853-
{
1854-
static_assert(std::is_same_v<Integer, UInt64>);
1855-
return length_field.get<UInt64>();
1856-
}
1857-
}
1858-
18591897
// return {is_positive, abs}
18601898
template <typename Integer>
1861-
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
1899+
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
18621900
{
1863-
if constexpr (std::is_same_v<Integer, Int64>)
1901+
Integer val = column.getElement(index);
1902+
if constexpr (
1903+
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
1904+
|| std::is_same_v<Integer, Int64>)
18641905
{
1865-
Int64 signed_length = start_field.get<Int64>();
1866-
1867-
if (signed_length < 0)
1868-
{
1869-
return {false, static_cast<size_t>(-signed_length)};
1870-
}
1871-
else
1872-
{
1873-
return {true, static_cast<size_t>(signed_length)};
1874-
}
1906+
if (val < 0)
1907+
return {false, static_cast<size_t>(-val)};
1908+
return {true, static_cast<size_t>(val)};
18751909
}
18761910
else
18771911
{
1878-
static_assert(std::is_same_v<Integer, UInt64>);
1879-
return {true, start_field.get<UInt64>()};
1912+
static_assert(
1913+
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
1914+
|| std::is_same_v<Integer, UInt64>);
1915+
return {true, val};
18801916
}
18811917
}
18821918

18831919
template <typename F>
18841920
static bool getNumberType(DataTypePtr type, F && f)
18851921
{
1886-
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
1922+
return castTypeToEither<
1923+
DataTypeUInt8,
1924+
DataTypeUInt16,
1925+
DataTypeUInt32,
1926+
DataTypeUInt64,
1927+
DataTypeInt8,
1928+
DataTypeInt16,
1929+
DataTypeInt32,
1930+
DataTypeInt64>(type.get(), std::forward<F>(f));
18871931
}
18881932
};
18891933

@@ -1922,16 +1966,28 @@ class FunctionRightUTF8 : public IFunction
19221966
bool is_length_type_valid
19231967
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
19241968
using LengthType = std::decay_t<decltype(length_type)>;
1925-
// Int64 / UInt64
19261969
using LengthFieldType = typename LengthType::FieldType;
19271970

1971+
const ColumnVector<LengthFieldType> * column_vector_length
1972+
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
1973+
if unlikely (!column_vector_length)
1974+
throw Exception(
1975+
fmt::format(
1976+
"Illegal type {} of argument 2 of function {}",
1977+
block.getByPosition(arguments[1]).type->getName(),
1978+
getName()),
1979+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
1980+
1981+
19281982
auto col_res = ColumnString::create();
19291983
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
19301984
{
19311985
if (column_length->isColumnConst())
19321986
{
19331987
// vector const
1934-
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
1988+
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
1989+
*column_vector_length,
1990+
0);
19351991

19361992
// for const 0, return const blank string.
19371993
if (0 == length)
@@ -1951,8 +2007,10 @@ class FunctionRightUTF8 : public IFunction
19512007
else
19522008
{
19532009
// vector vector
1954-
auto get_length_func = [&column_length](size_t i) {
1955-
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
2010+
auto get_length_func = [column_vector_length](size_t i) {
2011+
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
2012+
*column_vector_length,
2013+
i);
19562014
};
19572015
RightUTF8Impl::vectorVector(
19582016
col_string->getChars(),
@@ -1971,8 +2029,10 @@ class FunctionRightUTF8 : public IFunction
19712029
assert(col_string_from_const);
19722030
// When useDefaultImplementationForConstants is true, string and length are not both constants
19732031
assert(!column_length->isColumnConst());
1974-
auto get_length_func = [&column_length](size_t i) {
1975-
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
2032+
auto get_length_func = [column_vector_length](size_t i) {
2033+
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
2034+
*column_vector_length,
2035+
i);
19762036
};
19772037
RightUTF8Impl::constVector(
19782038
column_length->size(),
@@ -1999,22 +2059,15 @@ class FunctionRightUTF8 : public IFunction
19992059
template <typename F>
20002060
static bool getLengthType(DataTypePtr type, F && f)
20012061
{
2002-
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
2003-
}
2004-
2005-
template <typename Integer>
2006-
static size_t getValueFromLengthField(const Field & length_field)
2007-
{
2008-
if constexpr (std::is_same_v<Integer, Int64>)
2009-
{
2010-
Int64 signed_length = length_field.get<Int64>();
2011-
return signed_length < 0 ? 0 : signed_length;
2012-
}
2013-
else
2014-
{
2015-
static_assert(std::is_same_v<Integer, UInt64>);
2016-
return length_field.get<UInt64>();
2017-
}
2062+
return castTypeToEither<
2063+
DataTypeUInt8,
2064+
DataTypeUInt16,
2065+
DataTypeUInt32,
2066+
DataTypeUInt64,
2067+
DataTypeInt8,
2068+
DataTypeInt16,
2069+
DataTypeInt32,
2070+
DataTypeInt64>(type.get(), std::forward<F>(f));
20182071
}
20192072
};
20202073

0 commit comments

Comments
 (0)