77#include < Analyzer/Resolve/IdentifierResolver.h>
88#include < Analyzer/QueryNode.h>
99#include < Analyzer/TableNode.h>
10+ #include < Analyzer/UnionNode.h>
1011#include < Analyzer/FunctionNode.h>
1112#include < Analyzer/ColumnNode.h>
1213#include < Analyzer/InDepthQueryTreeVisitor.h>
@@ -34,10 +35,33 @@ namespace ErrorCodes
3435namespace
3536{
3637
37- struct HybridCastTask
38+ // / Collect Hybrid table expressions that require casts to normalize headers across segments.
39+ class HybridCastTablesCollector : public InDepthQueryTreeVisitor <HybridCastTablesCollector>
3840{
39- QueryTreeNodePtr table_expression;
40- ColumnsDescription cast_schema;
41+ public:
42+ explicit HybridCastTablesCollector (std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map_)
43+ : cast_map(cast_map_)
44+ {}
45+
46+ static bool needChildVisit (QueryTreeNodePtr &, QueryTreeNodePtr &) { return true ; }
47+
48+ void visitImpl (QueryTreeNodePtr & node)
49+ {
50+ const auto * table = node->as <TableNode>();
51+ if (!table)
52+ return ;
53+
54+ const auto * storage = table->getStorage ().get ();
55+ if (const auto * distributed = typeid_cast<const StorageDistributed *>(storage))
56+ {
57+ ColumnsDescription to_cast = distributed->getColumnsToCast ();
58+ if (!to_cast.empty ())
59+ cast_map.emplace (node.get (), std::move (to_cast)); // repeated table_expression can overwrite
60+ }
61+ }
62+
63+ private:
64+ std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map;
4165};
4266
4367// Visitor replaces all usages of the column with CAST(column, type) in the query tree.
@@ -55,8 +79,9 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
5579
5680 static bool needChildVisit (QueryTreeNodePtr &, QueryTreeNodePtr & child)
5781 {
58- auto child_type = child->getNodeType ();
59- return !(child_type == QueryTreeNodeType::QUERY || child_type == QueryTreeNodeType::UNION);
82+ // / Traverse all child nodes so casts also apply inside subqueries and UNION branches.
83+ (void )child;
84+ return true ;
6085 }
6186
6287 void visitImpl (QueryTreeNodePtr & node)
@@ -79,7 +104,6 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
79104 return ;
80105
81106 auto column_clone = std::static_pointer_cast<ColumnNode>(column_node->clone ());
82- column_clone->setColumnType (expected_column_opt->type );
83107
84108 auto cast_node = buildCastFunction (column_clone, expected_column_opt->type , context);
85109 const auto & alias = node->getAlias ();
@@ -99,38 +123,6 @@ class HybridCastVisitor : public InDepthQueryTreeVisitor<HybridCastVisitor>
99123
100124} // namespace
101125
102- void collectHybridTables (const QueryTreeNodePtr & join_tree, std::unordered_map<const IQueryTreeNode *, ColumnsDescription> & cast_map)
103- {
104- if (!join_tree)
105- return ;
106- if (const auto * table = join_tree->as <TableNode>())
107- {
108- const auto * storage = table->getStorage ().get ();
109- if (const auto * distributed = typeid_cast<const StorageDistributed *>(storage))
110- {
111- ColumnsDescription to_cast = distributed->getColumnsToCast ();
112- if (!to_cast.empty ())
113- cast_map.emplace (join_tree.get (), std::move (to_cast)); // repeated table_expression can overwrite
114- }
115- return ;
116- }
117- if (const auto * func = join_tree->as <FunctionNode>())
118- {
119- for (auto & child : func->getArguments ().getNodes ())
120- collectHybridTables (child, cast_map);
121- return ;
122- }
123- if (const auto * query = join_tree->as <QueryNode>())
124- {
125- collectHybridTables (query->getJoinTree (), cast_map);
126- }
127- if (const auto * union_node = join_tree->as <UnionNode>())
128- {
129- for (auto & child : union_node->getQueries ().getNodes ())
130- collectHybridTables (child, cast_map);
131- }
132- }
133-
134126void HybridCastsPass::run (QueryTreeNodePtr & query_tree_node, ContextPtr context)
135127{
136128 const auto & settings = context->getSettingsRef ();
@@ -142,7 +134,8 @@ void HybridCastsPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context
142134 return ;
143135
144136 std::unordered_map<const IQueryTreeNode *, ColumnsDescription> cast_map;
145- collectHybridTables (query->getJoinTree (), cast_map);
137+ HybridCastTablesCollector collector (cast_map);
138+ collector.visit (query_tree_node);
146139 if (cast_map.empty ())
147140 return ;
148141
0 commit comments