Skip to content

Commit e186d9e

Browse files
Merge pull request ClickHouse#76078 from ClickHouse/analyzer-scalar-correlated-subqueries
Support correlated subqueries in WHERE clause
2 parents e03c3bb + ffe3ccd commit e186d9e

File tree

69 files changed

+2001
-173
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+2001
-173
lines changed

src/Analyzer/HashUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ using QueryTreeNodeConstRawPtrWithHashSet = std::unordered_set<QueryTreeNodeCons
5454
template <typename Value>
5555
using QueryTreeNodePtrWithHashMap = std::unordered_map<QueryTreeNodePtrWithHash, Value>;
5656

57+
class ColumnNode;
58+
using ColumnNodePtr = std::shared_ptr<ColumnNode>;
59+
using ColumnNodePtrWithHash = QueryTreeNodeWithHash<ColumnNodePtr>;
60+
using ColumnNodePtrWithHashSet = std::unordered_set<ColumnNodePtrWithHash>;
61+
5762
}
5863

5964
template <typename T, bool compare_aliases, bool compare_types>

src/Analyzer/IQueryTreeNode.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ enum class QueryTreeNodeType : uint8_t
4747
ARRAY_JOIN,
4848
CROSS_JOIN,
4949
JOIN,
50-
UNION
50+
UNION,
5151
};
5252

5353
/// Convert query tree node type to string
@@ -91,12 +91,12 @@ class IQueryTreeNode : public TypePromotion<IQueryTreeNode>
9191
*/
9292
virtual DataTypePtr getResultType() const
9393
{
94-
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is not supported for {} query node", getNodeTypeName());
94+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is not supported for {} query tree node", getNodeTypeName());
9595
}
9696

9797
virtual void convertToNullable()
9898
{
99-
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method convertToNullable is not supported for {} query node", getNodeTypeName());
99+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method convertToNullable is not supported for {} query tree node", getNodeTypeName());
100100
}
101101

102102
struct CompareOptions

src/Analyzer/Passes/RemoveUnusedProjectionColumnsPass.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
#include <Functions/FunctionFactory.h>
44

5-
#include <Analyzer/InDepthQueryTreeVisitor.h>
5+
#include <Analyzer/AggregationUtils.h>
6+
#include <Analyzer/ColumnNode.h>
67
#include <Analyzer/FunctionNode.h>
8+
#include <Analyzer/InDepthQueryTreeVisitor.h>
79
#include <Analyzer/QueryNode.h>
8-
#include <Analyzer/ColumnNode.h>
910
#include <Analyzer/SortNode.h>
10-
#include <Analyzer/AggregationUtils.h>
11+
#include <Analyzer/UnionNode.h>
1112
#include <Analyzer/Utils.h>
1213

1314
namespace DB
@@ -48,6 +49,29 @@ class CollectUsedColumnsVisitor : public InDepthQueryTreeVisitorWithContext<Coll
4849
return;
4950
}
5051

52+
if (node_type == QueryTreeNodeType::FUNCTION)
53+
{
54+
auto & function_node = node->as<FunctionNode &>();
55+
56+
if (function_node.getFunctionName() != "exists")
57+
return;
58+
59+
const auto & subquery_argument = function_node.getArguments().getNodes().front();
60+
auto * query_node = subquery_argument->as<QueryNode>();
61+
auto * union_node = subquery_argument->as<UnionNode>();
62+
63+
const auto & correlated_columns = query_node != nullptr ? query_node->getCorrelatedColumns() : union_node->getCorrelatedColumns();
64+
for (const auto & correlated_column : correlated_columns)
65+
{
66+
auto * column_node = correlated_column->as<ColumnNode>();
67+
auto column_source_node = column_node->getColumnSource();
68+
auto column_source_node_type = column_source_node->getNodeType();
69+
if (column_source_node_type == QueryTreeNodeType::QUERY || column_source_node_type == QueryTreeNodeType::UNION)
70+
query_or_union_node_to_used_columns[column_source_node].insert(column_node->getColumnName());
71+
}
72+
return;
73+
}
74+
5175
if (node_type != QueryTreeNodeType::COLUMN)
5276
return;
5377

@@ -56,10 +80,15 @@ class CollectUsedColumnsVisitor : public InDepthQueryTreeVisitorWithContext<Coll
5680
return;
5781

5882
auto column_source_node = column_node.getColumnSource();
59-
auto column_source_node_type = column_source_node->getNodeType();
6083

61-
if (column_source_node_type == QueryTreeNodeType::QUERY || column_source_node_type == QueryTreeNodeType::UNION)
62-
query_or_union_node_to_used_columns[column_source_node].insert(column_node.getColumnName());
84+
auto it = query_or_union_node_to_used_columns.find(column_source_node);
85+
/// If the source node is not found in the map then:
86+
/// 1. Tt's either not a Query or Union node.
87+
/// 2. It's a correlated column and it comes from the outer scope.
88+
if (it != query_or_union_node_to_used_columns.end())
89+
{
90+
it->second.insert(column_node.getColumnName());
91+
}
6392
}
6493

6594
void reset()

src/Analyzer/QueryNode.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <Parsers/ASTSelectWithUnionQuery.h>
2323
#include <Parsers/ASTSetQuery.h>
2424

25+
#include <Analyzer/ColumnNode.h>
2526
#include <Analyzer/InterpolateNode.h>
2627
#include <Analyzer/UnionNode.h>
2728
#include <Analyzer/Utils.h>
@@ -46,6 +47,7 @@ QueryNode::QueryNode(ContextMutablePtr context_, SettingsChanges settings_change
4647
children[window_child_index] = std::make_shared<ListNode>();
4748
children[order_by_child_index] = std::make_shared<ListNode>();
4849
children[limit_by_child_index] = std::make_shared<ListNode>();
50+
children[correlated_columns_list_index] = std::make_shared<ListNode>();
4951
}
5052

5153
QueryNode::QueryNode(ContextMutablePtr context_)
@@ -108,6 +110,31 @@ void QueryNode::removeUnusedProjectionColumns(const std::unordered_set<size_t> &
108110
}
109111
}
110112

113+
ColumnNodePtrWithHashSet QueryNode::getCorrelatedColumnsSet() const
114+
{
115+
ColumnNodePtrWithHashSet result;
116+
117+
const auto & correlated_columns = getCorrelatedColumns().getNodes();
118+
result.reserve(correlated_columns.size());
119+
120+
for (const auto & column : correlated_columns)
121+
{
122+
result.insert(std::static_pointer_cast<ColumnNode>(column));
123+
}
124+
return result;
125+
}
126+
127+
void QueryNode::addCorrelatedColumn(ColumnNodePtr correlated_column)
128+
{
129+
auto & correlated_columns = getCorrelatedColumns().getNodes();
130+
for (const auto & column : correlated_columns)
131+
{
132+
if (column->isEqual(*correlated_column))
133+
return;
134+
}
135+
correlated_columns.push_back(correlated_column);
136+
}
137+
111138
void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
112139
{
113140
buffer << std::string(indent, ' ') << "QUERY id: " << format_state.getNodeId(this);
@@ -153,6 +180,12 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s
153180
if (!cte_name.empty())
154181
buffer << ", cte_name: " << cte_name;
155182

183+
if (isCorrelated())
184+
{
185+
buffer << ", is_correlated: 1\n" << std::string(indent + 2, ' ') << "CORRELATED COLUMNS\n";
186+
getCorrelatedColumns().dumpTreeImpl(buffer, format_state, indent + 4);
187+
}
188+
156189
if (hasWith())
157190
{
158191
buffer << '\n' << std::string(indent + 2, ' ') << "WITH\n";

src/Analyzer/QueryNode.h

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Core/NamesAndTypes.h>
66
#include <Core/Field.h>
77

8+
#include <Analyzer/HashUtils.h>
89
#include <Analyzer/IQueryTreeNode.h>
910
#include <Analyzer/ListNode.h>
1011
#include <Analyzer/TableExpressionModifiers.h>
@@ -59,6 +60,9 @@ namespace DB
5960
class QueryNode;
6061
using QueryNodePtr = std::shared_ptr<QueryNode>;
6162

63+
class ColumnNode;
64+
using ColumnNodePtr = std::shared_ptr<ColumnNode>;
65+
6266
class QueryNode final : public IQueryTreeNode
6367
{
6468
public:
@@ -619,6 +623,30 @@ class QueryNode final : public IQueryTreeNode
619623
/// Remove unused projection columns
620624
void removeUnusedProjectionColumns(const std::unordered_set<size_t> & used_projection_columns_indexes);
621625

626+
bool isCorrelated() const
627+
{
628+
return !children[correlated_columns_list_index]->as<ListNode>()->getNodes().empty();
629+
}
630+
631+
QueryTreeNodePtr & getCorrelatedColumnsNode()
632+
{
633+
return children[correlated_columns_list_index];
634+
}
635+
636+
ListNode & getCorrelatedColumns()
637+
{
638+
return children[correlated_columns_list_index]->as<ListNode &>();
639+
}
640+
641+
const ListNode & getCorrelatedColumns() const
642+
{
643+
return children[correlated_columns_list_index]->as<ListNode &>();
644+
}
645+
646+
ColumnNodePtrWithHashSet getCorrelatedColumnsSet() const;
647+
648+
void addCorrelatedColumn(ColumnNodePtr correlated_column);
649+
622650
QueryTreeNodeType getNodeType() const override
623651
{
624652
return QueryTreeNodeType::QUERY;
@@ -675,7 +703,8 @@ class QueryNode final : public IQueryTreeNode
675703
static constexpr size_t limit_by_child_index = 13;
676704
static constexpr size_t limit_child_index = 14;
677705
static constexpr size_t offset_child_index = 15;
678-
static constexpr size_t children_size = offset_child_index + 1;
706+
static constexpr size_t correlated_columns_list_index = 16;
707+
static constexpr size_t children_size = correlated_columns_list_index + 1;
679708
};
680709

681710
}

src/Analyzer/QueryTreePassManager.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class ValidationChecker : public InDepthQueryTreeVisitor<ValidationChecker>
107107
if (isNameOfInFunction(function->getFunctionName()))
108108
return;
109109

110+
if (function->getFunctionName() == "exists")
111+
return;
112+
110113
const auto & expected_argument_types = function->getArgumentTypes();
111114
size_t expected_argument_types_size = expected_argument_types.size();
112115
auto actual_argument_columns = function->getArgumentColumns();

src/Analyzer/Resolve/QueryAnalyzer.cpp

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <Functions/IFunctionAdaptors.h>
1919
#include <Functions/UserDefined/UserDefinedExecutableFunctionFactory.h>
2020
#include <Functions/UserDefined/UserDefinedSQLFunctionFactory.h>
21+
#include <Functions/exists.h>
2122
#include <Functions/grouping.h>
2223

2324
#include <TableFunctions/TableFunctionFactory.h>
@@ -113,6 +114,7 @@ namespace Setting
113114
extern const SettingsBool allow_suspicious_types_in_order_by;
114115
extern const SettingsBool allow_not_comparable_types_in_order_by;
115116
extern const SettingsBool use_concurrency_control;
117+
extern const SettingsBool allow_experimental_correlated_subqueries;
116118
extern const SettingsString implicit_table_at_top_level;
117119
}
118120

@@ -1380,12 +1382,14 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifierInParentScopes(const
13801382
{
13811383
auto current = nodes_to_process.back();
13821384
nodes_to_process.pop_back();
1383-
if (auto * current_column = current->as<ColumnNode>())
1385+
if (ColumnNodePtr current_column = std::dynamic_pointer_cast<ColumnNode>(current))
13841386
{
1385-
if (isDependentColumn(&scope, current_column->getColumnSource()))
1387+
auto is_correlated_column = checkCorrelatedColumn(&scope, current_column);
1388+
if (is_correlated_column && !scope.context->getSettingsRef()[Setting::allow_experimental_correlated_subqueries])
13861389
{
13871390
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
1388-
"Resolved identifier '{}' in parent scope to expression '{}' with correlated column '{}'. In scope {}",
1391+
"Resolved identifier '{}' in parent scope to expression '{}' with correlated column '{}'"
1392+
" (Enable 'allow_experimental_correlated_subqueries' setting to allow correlated subqueries execution). In scope {}",
13891393
identifier_lookup.identifier.getFullName(),
13901394
resolved_identifier->formatASTForErrorMessage(),
13911395
current_column->getColumnName(),
@@ -2848,27 +2852,6 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
28482852
}
28492853
}
28502854

2851-
if (is_special_function_exists)
2852-
{
2853-
checkFunctionNodeHasEmptyNullsAction(*function_node_ptr);
2854-
/// Rewrite EXISTS (subquery) into 1 IN (SELECT 1 FROM (subquery) LIMIT 1).
2855-
auto & exists_subquery_argument = function_node_ptr->getArguments().getNodes().at(0);
2856-
2857-
auto constant_data_type = std::make_shared<DataTypeUInt64>();
2858-
2859-
auto in_subquery = std::make_shared<QueryNode>(Context::createCopy(scope.context));
2860-
in_subquery->setIsSubquery(true);
2861-
in_subquery->getProjection().getNodes().push_back(std::make_shared<ConstantNode>(1UL, constant_data_type));
2862-
in_subquery->getJoinTree() = exists_subquery_argument;
2863-
in_subquery->getLimit() = std::make_shared<ConstantNode>(1UL, constant_data_type);
2864-
2865-
function_node_ptr = std::make_shared<FunctionNode>("in");
2866-
function_node_ptr->getArguments().getNodes() = {std::make_shared<ConstantNode>(1UL, constant_data_type), in_subquery};
2867-
node = function_node_ptr;
2868-
function_name = "in";
2869-
is_special_function_in = true;
2870-
}
2871-
28722855
if (is_special_function_if && !function_node_ptr->getArguments().getNodes().empty())
28732856
{
28742857
checkFunctionNodeHasEmptyNullsAction(*function_node_ptr);
@@ -2925,12 +2908,61 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
29252908
}
29262909

29272910
/// Resolve function arguments
2928-
bool allow_table_expressions = is_special_function_in;
2911+
bool allow_table_expressions = is_special_function_in || is_special_function_exists;
29292912
auto arguments_projection_names = resolveExpressionNodeList(function_node_ptr->getArgumentsNode(),
29302913
scope,
29312914
true /*allow_lambda_expression*/,
29322915
allow_table_expressions /*allow_table_expression*/);
29332916

2917+
if (is_special_function_exists)
2918+
{
2919+
checkFunctionNodeHasEmptyNullsAction(*function_node_ptr);
2920+
/// Rewrite EXISTS (subquery) into 1 IN (SELECT 1 FROM (subquery) LIMIT 1).
2921+
auto & exists_subquery_argument = function_node_ptr->getArguments().getNodes().at(0);
2922+
bool correlated_exists_subquery = exists_subquery_argument->getNodeType() == QueryTreeNodeType::QUERY
2923+
? exists_subquery_argument->as<QueryNode>()->isCorrelated()
2924+
: exists_subquery_argument->as<UnionNode>()->isCorrelated();
2925+
if (!correlated_exists_subquery)
2926+
{
2927+
auto constant_data_type = std::make_shared<DataTypeUInt64>();
2928+
2929+
auto in_subquery = std::make_shared<QueryNode>(Context::createCopy(scope.context));
2930+
in_subquery->setIsSubquery(true);
2931+
in_subquery->getProjection().getNodes().push_back(std::make_shared<ConstantNode>(1UL, constant_data_type));
2932+
in_subquery->getJoinTree() = exists_subquery_argument;
2933+
in_subquery->getLimit() = std::make_shared<ConstantNode>(1UL, constant_data_type);
2934+
2935+
function_node_ptr = std::make_shared<FunctionNode>("in");
2936+
function_node_ptr->getArguments().getNodes() = {
2937+
std::make_shared<ConstantNode>(1UL, constant_data_type),
2938+
std::move(in_subquery)
2939+
};
2940+
2941+
/// Resolve modified arguments
2942+
arguments_projection_names = resolveExpressionNodeList(function_node_ptr->getArgumentsNode(),
2943+
scope,
2944+
true /*allow_lambda_expression*/,
2945+
true /*allow_table_expression*/);
2946+
2947+
node = function_node_ptr;
2948+
function_name = "in";
2949+
is_special_function_in = true;
2950+
}
2951+
else
2952+
{
2953+
/// Subquery is correlated and EXISTS can not be replaced by IN function.
2954+
/// EXISTS function will be replated by JOIN during query planning.
2955+
auto function_exists = std::make_shared<FunctionExists>();
2956+
function_node_ptr->resolveAsFunction(
2957+
std::make_shared<FunctionToFunctionBaseAdaptor>(
2958+
function_exists, DataTypes{}, function_exists->getReturnTypeImpl({})
2959+
)
2960+
);
2961+
2962+
return { calculateFunctionProjectionName(node, parameters_projection_names, arguments_projection_names) };
2963+
}
2964+
}
2965+
29342966
/// Mask arguments if needed
29352967
if (!scope.context->getSettingsRef()[Setting::format_display_secrets_in_show_and_select])
29362968
{
@@ -2981,6 +3013,10 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
29813013
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 2 arguments", function_name);
29823014

29833015
auto & in_second_argument = function_in_arguments_nodes[1];
3016+
if (isCorrelatedQueryOrUnionNode(function_in_arguments_nodes[0]) || isCorrelatedQueryOrUnionNode(function_in_arguments_nodes[1]))
3017+
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
3018+
"Correlated subqueries are not supported as IN function arguments yet, but found in expression: {}",
3019+
node->formatASTForErrorMessage());
29843020
auto * table_node = in_second_argument->as<TableNode>();
29853021
auto * table_function_node = in_second_argument->as<TableFunctionNode>();
29863022

@@ -5143,6 +5179,16 @@ void QueryAnalyzer::resolveJoin(QueryTreeNodePtr & join_node, IdentifierResolveS
51435179
resolveQueryJoinTreeNode(join_node_typed.getRightTableExpression(), scope, expressions_visitor);
51445180
validateJoinTableExpressionWithoutAlias(join_node, join_node_typed.getRightTableExpression(), scope);
51455181

5182+
if (isCorrelatedQueryOrUnionNode(join_node_typed.getLeftTableExpression()))
5183+
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
5184+
"Correlated subqueries are not supported in JOINs yet, but found in expression: {}",
5185+
join_node_typed.getLeftTableExpression()->formatASTForErrorMessage());
5186+
5187+
if (isCorrelatedQueryOrUnionNode(join_node_typed.getRightTableExpression()))
5188+
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
5189+
"Correlated subqueries are not supported in JOINs yet, but found in expression: {}",
5190+
join_node_typed.getRightTableExpression()->formatASTForErrorMessage());
5191+
51465192
if (!join_node_typed.getLeftTableExpression()->hasAlias() && !join_node_typed.getRightTableExpression()->hasAlias())
51475193
checkDuplicateTableNamesOrAliasForPasteJoin(join_node_typed, scope);
51485194

src/Analyzer/TableNode.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ void TableNode::updateTreeHashImpl(HashState & state, CompareOptions) const
8585
}
8686
else
8787
{
88-
auto full_name = storage_id.getFullNameNotQuoted();
88+
// In case of cross-replication we don't know what database is used for the table.
89+
// `storage_id.hasDatabase()` can return false only on the initiator node.
90+
// Each shard will use the default database (in the case of cross-replication shards may have different defaults).
91+
auto full_name = storage_id.hasDatabase() ? storage_id.getFullNameNotQuoted() : storage_id.getTableName();
8992
state.update(full_name.size());
9093
state.update(full_name);
9194
}

0 commit comments

Comments
 (0)