Skip to content

Commit 3f65bfd

Browse files
authored
[refine](util) add cast_to_column to cast ColumnPtr (#58092)
This is similar to std::dynamic_pointer_cast. ```C++ ColumnPtr column_ptr = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3}); EXPECT_EQ(column_ptr->use_count(), 1); ColumnInt32::Ptr column_i32 = cast_to_column<ColumnInt32>(column_ptr); EXPECT_TRUE(column_i32); EXPECT_EQ(column_ptr->use_count(), 2); EXPECT_EQ(column_i32->use_count(), 2); MutableColumnPtr column_ptr = ColumnInt32::create(); EXPECT_EQ(column_ptr->use_count(), 1); ColumnInt32::MutablePtr column_i32 = cast_to_column<ColumnInt32>(std::move(column_ptr)); EXPECT_TRUE(column_i32); EXPECT_EQ(column_i32->use_count(), 1); ```
1 parent edbfb8f commit 3f65bfd

File tree

6 files changed

+58
-10
lines changed

6 files changed

+58
-10
lines changed

be/src/vec/columns/column.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "common/status.h"
3131
#include "olap/olap_common.h"
3232
#include "runtime/define_primitive_type.h"
33+
#include "vec/common/assert_cast.h"
3334
#include "vec/common/cow.h"
3435
#include "vec/common/pod_array_fwd.h"
3536
#include "vec/common/string_ref.h"
@@ -759,7 +760,18 @@ ColumnType::Ptr check_and_get_column_ptr(const ColumnPtr& column) {
759760
if (raw_type_ptr == nullptr) {
760761
return nullptr;
761762
}
762-
return typename ColumnType::Ptr(raw_type_ptr);
763+
return ColumnType::cast_to_column_ptr(raw_type_ptr);
764+
}
765+
766+
template <typename ColumnType>
767+
ColumnType::Ptr cast_to_column(const ColumnPtr& column) {
768+
const ColumnType* raw_type_ptr = assert_cast<const ColumnType*>(column.get());
769+
return ColumnType::cast_to_column_ptr(raw_type_ptr);
770+
}
771+
template <typename ColumnType>
772+
ColumnType::MutablePtr cast_to_column(MutableColumnPtr column) {
773+
ColumnType* raw_type_ptr = assert_cast<ColumnType*>(column.get());
774+
return ColumnType::cast_to_column_mutptr(raw_type_ptr);
763775
}
764776

765777
/// True if column's an ColumnConst instance. It's just a syntax sugar for type check.

be/src/vec/common/cow.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,12 @@ class COWHelper : public Base {
417417
}
418418
#include "common/compile_check_avoid_end.h"
419419

420+
static Ptr cast_to_column_ptr(const Derived* raw_type_ptr) { return Ptr(raw_type_ptr); }
421+
422+
static MutablePtr cast_to_column_mutptr(Derived* raw_type_ptr) {
423+
return MutablePtr(raw_type_ptr);
424+
}
425+
420426
typename Base::MutablePtr clone() const override {
421427
return typename Base::MutablePtr(new Derived(static_cast<const Derived&>(*this)));
422428
}

be/src/vec/functions/dictionary.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ void IDictionary::load_values(const std::vector<ColumnPtr>& values_column) {
116116
using ValueRealDataType = std::decay_t<decltype(type)>;
117117
auto& att = _values_data[i];
118118
auto init_column_with_type = [&](auto& column_with_type) {
119-
column_with_type.column = value_column_without_nullable;
119+
using Type = std::decay_t<decltype(column_with_type)>::RealColumnType;
120+
column_with_type.column =
121+
cast_to_column<Type>(value_column_without_nullable);
120122
// if original value is nullable, the null_map must be not null
121123
if (values_column[i]->is_nullable()) {
122124
column_with_type.null_map =

be/src/vec/functions/dictionary.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,14 @@ class IDictionary {
132132
template <typename Type>
133133
struct ColumnWithType {
134134
// OutputColumnType is used as the result column type
135-
using OutputColumnType = Type::ColumnType;
136-
ColumnPtr column;
137135
ColumnPtr null_map;
138136
// RealColumnType is the real type of the column, as there may be ColumnString64, but the result column will not be ColumnString64
137+
138+
using OutputColumnType = Type::ColumnType;
139139
using RealColumnType = std::conditional_t<std::is_same_v<DictDataTypeString64, Type>,
140140
ColumnString64, OutputColumnType>;
141-
const RealColumnType* get() const {
142-
return assert_cast<const RealColumnType*, TypeCheckOnRelease::DISABLE>(column.get());
143-
}
141+
RealColumnType::Ptr column;
142+
const RealColumnType* get() const { return column.get(); }
144143

145144
const ColumnUInt8* get_null_map() const {
146145
if (!null_map) {

be/src/vec/functions/function.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const Colum
4949
ColumnPtr src_not_nullable = src;
5050
MutableColumnPtr mutable_result_null_map_column;
5151

52-
if (const auto* nullable = check_and_get_column<ColumnNullable>(*src)) {
52+
if (auto nullable = check_and_get_column_ptr<ColumnNullable>(src)) {
5353
src_not_nullable = nullable->get_nested_column_ptr();
5454
result_null_map_column = nullable->get_null_map_column_ptr();
5555
}
@@ -60,8 +60,7 @@ ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const Colum
6060
continue;
6161
}
6262

63-
if (const auto* nullable = assert_cast<const ColumnNullable*>(elem.column.get());
64-
nullable->has_null()) {
63+
if (auto nullable = cast_to_column<ColumnNullable>(elem.column); nullable->has_null()) {
6564
const ColumnPtr& null_map_column = nullable->get_null_map_column_ptr();
6665
if (!result_null_map_column) { // NOLINT(bugprone-use-after-move)
6766
result_null_map_column = null_map_column->clone_resized(input_rows_count);

be/test/vec/columns/check_and_get_column_ptr_test.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <gtest/gtest.h>
1919

20+
#include "runtime/primitive_type.h"
2021
#include "testutil/column_helper.h"
2122
#include "vec/columns/column.h"
2223
#include "vec/columns/column_nullable.h"
@@ -139,4 +140,33 @@ TEST(CheckAndGetColumnPtrTest, destructstest) {
139140

140141
EXPECT_EQ(column_ptr->use_count(), 1);
141142
}
143+
144+
TEST(CheckAndGetColumnPtrTest, cast_to_column_immut) {
145+
{
146+
ColumnPtr column_ptr = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3});
147+
148+
EXPECT_EQ(column_ptr->use_count(), 1);
149+
ColumnInt32::Ptr column_i32 = cast_to_column<ColumnInt32>(column_ptr);
150+
151+
EXPECT_TRUE(column_i32);
152+
153+
EXPECT_EQ(column_ptr->use_count(), 2);
154+
155+
EXPECT_EQ(column_i32->use_count(), 2);
156+
}
157+
}
158+
159+
TEST(CheckAndGetColumnPtrTest, cast_to_column_mut) {
160+
{
161+
MutableColumnPtr column_ptr = ColumnInt32::create();
162+
163+
EXPECT_EQ(column_ptr->use_count(), 1);
164+
ColumnInt32::MutablePtr column_i32 = cast_to_column<ColumnInt32>(std::move(column_ptr));
165+
166+
EXPECT_TRUE(column_i32);
167+
168+
EXPECT_EQ(column_i32->use_count(), 1);
169+
}
170+
}
171+
142172
} // namespace doris::vectorized

0 commit comments

Comments
 (0)