Skip to content

Commit 90cbc5e

Browse files
committed
better
1 parent 2186269 commit 90cbc5e

File tree

2 files changed

+33
-39
lines changed

2 files changed

+33
-39
lines changed

src/Analyzer/Passes/HybridCastsPass.cpp

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <Analyzer/Resolve/IdentifierResolver.h>
88
#include <Analyzer/QueryNode.h>
99
#include <Analyzer/TableNode.h>
10+
#include <Analyzer/UnionNode.h>
1011
#include <Analyzer/FunctionNode.h>
1112
#include <Analyzer/ColumnNode.h>
1213
#include <Analyzer/InDepthQueryTreeVisitor.h>
@@ -34,10 +35,33 @@ namespace ErrorCodes
3435
namespace
3536
{
3637

37-
struct HybridCastTask
38+
/// Collect Hybrid table expressions that require casts to normalize headers across segments.
39+
class HybridCastTablesCollector : public InDepthQueryTreeVisitor<HybridCastTablesCollector>
3840
{
39-
QueryTreeNodePtr table_expression;
40-
ColumnsDescription cast_schema;
41+
public:
42+
explicit HybridCastTablesCollector(std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map_)
43+
: cast_map(cast_map_)
44+
{}
45+
46+
static bool needChildVisit(QueryTreeNodePtr &, QueryTreeNodePtr &) { return true; }
47+
48+
void visitImpl(QueryTreeNodePtr & node)
49+
{
50+
const auto * table = node->as<TableNode>();
51+
if (!table)
52+
return;
53+
54+
const auto * storage = table->getStorage().get();
55+
if (const auto * distributed = typeid_cast<const StorageDistributed *>(storage))
56+
{
57+
ColumnsDescription to_cast = distributed->getColumnsToCast();
58+
if (!to_cast.empty())
59+
cast_map.emplace(node.get(), std::move(to_cast)); // repeated table_expression can overwrite
60+
}
61+
}
62+
63+
private:
64+
std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map;
4165
};
4266

4367
// Visitor replaces all usages of the column with CAST(column, type) in the query tree.
@@ -55,8 +79,9 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
5579

5680
static bool needChildVisit(QueryTreeNodePtr &, QueryTreeNodePtr & child)
5781
{
58-
auto child_type = child->getNodeType();
59-
return !(child_type == QueryTreeNodeType::QUERY || child_type == QueryTreeNodeType::UNION);
82+
/// Traverse all child nodes so casts also apply inside subqueries and UNION branches.
83+
(void)child;
84+
return true;
6085
}
6186

6287
void visitImpl(QueryTreeNodePtr & node)
@@ -79,7 +104,6 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
79104
return;
80105

81106
auto column_clone = std::static_pointer_cast<ColumnNode>(column_node->clone());
82-
column_clone->setColumnType(expected_column_opt->type);
83107

84108
auto cast_node = buildCastFunction(column_clone, expected_column_opt->type, context);
85109
const auto & alias = node->getAlias();
@@ -99,38 +123,6 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
99123

100124
} // namespace
101125

102-
void collectHybridTables(const QueryTreeNodePtr & join_tree, std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map)
103-
{
104-
if (!join_tree)
105-
return;
106-
if (const auto * table = join_tree->as<TableNode>())
107-
{
108-
const auto * storage = table->getStorage().get();
109-
if (const auto * distributed = typeid_cast<const StorageDistributed *>(storage))
110-
{
111-
ColumnsDescription to_cast = distributed->getColumnsToCast();
112-
if (!to_cast.empty())
113-
cast_map.emplace(join_tree.get(), std::move(to_cast)); // repeated table_expression can overwrite
114-
}
115-
return;
116-
}
117-
if (const auto * func = join_tree->as<FunctionNode>())
118-
{
119-
for (auto & child : func->getArguments().getNodes())
120-
collectHybridTables(child, cast_map);
121-
return;
122-
}
123-
if (const auto * query = join_tree->as<QueryNode>())
124-
{
125-
collectHybridTables(query->getJoinTree(), cast_map);
126-
}
127-
if (const auto * union_node = join_tree->as<UnionNode>())
128-
{
129-
for (auto & child : union_node->getQueries().getNodes())
130-
collectHybridTables(child, cast_map);
131-
}
132-
}
133-
134126
void HybridCastsPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
135127
{
136128
const auto & settings = context->getSettingsRef();
@@ -142,7 +134,8 @@ void HybridCastsPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context
142134
return;
143135

144136
std::unordered_map<const IQueryTreeNode *, ColumnsDescription> cast_map;
145-
collectHybridTables(query->getJoinTree(), cast_map);
137+
HybridCastTablesCollector collector(cast_map);
138+
collector.visit(query_tree_node);
146139
if (cast_map.empty())
147140
return;
148141

src/Storages/StorageDistributed.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
#include <cassert>
121121
#include <boost/algorithm/string/find_iterator.hpp>
122122
#include <boost/algorithm/string/finder.hpp>
123+
#include <fmt/ranges.h>
123124

124125

125126
namespace fs = std::filesystem;

0 commit comments

Comments
 (0)