Skip to content

Commit b7d1a38

Browse files
Merge pull request ClickHouse#79600 from ClickHouse/analyzer-correlated-scalar
Support scalar correlated subqueries in WHERE clause
2 parents 144bb7e + b67fdae commit b7d1a38

21 files changed

+800
-51
lines changed

src/Analyzer/QueryNode.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <memory>
12
#include <Analyzer/QueryNode.h>
23

34
#include <fmt/core.h>
@@ -34,6 +35,7 @@ namespace ErrorCodes
3435
{
3536
extern const int LOGICAL_ERROR;
3637
extern const int BAD_ARGUMENTS;
38+
extern const int UNSUPPORTED_METHOD;
3739
}
3840

3941
QueryNode::QueryNode(ContextMutablePtr context_, SettingsChanges settings_changes_)
@@ -135,6 +137,22 @@ void QueryNode::addCorrelatedColumn(const QueryTreeNodePtr & correlated_column)
135137
correlated_columns.push_back(correlated_column);
136138
}
137139

140+
DataTypePtr QueryNode::getResultType() const
141+
{
142+
if (isCorrelated())
143+
{
144+
if (projection_columns.size() == 1)
145+
{
146+
return projection_columns[0].type;
147+
}
148+
else
149+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
150+
"Method getResultType is supported only for correlated query node with 1 column, but got {}",
151+
projection_columns.size());
152+
}
153+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is supported only for correlated query node");
154+
}
155+
138156
void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
139157
{
140158
buffer << std::string(indent, ' ') << "QUERY id: " << format_state.getNodeId(this);

src/Analyzer/QueryNode.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,10 @@ class QueryNode final : public IQueryTreeNode
647647

648648
void addCorrelatedColumn(const QueryTreeNodePtr & correlated_column);
649649

650+
/// Returns result type of projection expression if query is correlated
651+
/// or throws an exception otherwise.
652+
DataTypePtr getResultType() const override;
653+
650654
QueryTreeNodeType getNodeType() const override
651655
{
652656
return QueryTreeNodeType::QUERY;

src/Analyzer/Resolve/QueryAnalyzer.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,11 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
547547
node->getNodeTypeName(),
548548
node->formatASTForErrorMessage());
549549

550+
bool is_correlated_subquery = (query_node != nullptr && query_node->isCorrelated())
551+
|| (union_node != nullptr && union_node->isCorrelated());
552+
if (is_correlated_subquery)
553+
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot evaluate correlated scalar subquery");
554+
550555
auto & context = scope.context;
551556

552557
Block scalar_block;
@@ -2533,7 +2538,8 @@ ProjectionName QueryAnalyzer::resolveWindow(QueryTreeNodePtr & node, IdentifierR
25332538
auto & window_node = node->as<WindowNode &>();
25342539
window_node.setParentWindowName({});
25352540

2536-
ProjectionNames partition_by_projection_names = resolveExpressionNodeList(window_node.getPartitionByNode(),
2541+
ProjectionNames partition_by_projection_names = resolveExpressionNodeList(
2542+
window_node.getPartitionByNode(),
25372543
scope,
25382544
false /*allow_lambda_expression*/,
25392545
false /*allow_table_expression*/);
@@ -2743,7 +2749,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
27432749

27442750
/// Resolve function parameters
27452751

2746-
auto parameters_projection_names = resolveExpressionNodeList(function_node_ptr->getParametersNode(),
2752+
auto parameters_projection_names = resolveExpressionNodeList(
2753+
function_node_ptr->getParametersNode(),
27472754
scope,
27482755
false /*allow_lambda_expression*/,
27492756
false /*allow_table_expression*/);
@@ -2924,7 +2931,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
29242931

29252932
/// Resolve function arguments
29262933
bool allow_table_expressions = is_special_function_in || is_special_function_exists;
2927-
auto arguments_projection_names = resolveExpressionNodeList(function_node_ptr->getArgumentsNode(),
2934+
auto arguments_projection_names = resolveExpressionNodeList(
2935+
function_node_ptr->getArgumentsNode(),
29282936
scope,
29292937
true /*allow_lambda_expression*/,
29302938
allow_table_expressions /*allow_table_expression*/);
@@ -3929,7 +3937,11 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(
39293937
else
39303938
resolveUnion(node, subquery_scope);
39313939

3932-
if (!allow_table_expression)
3940+
bool is_correlated_subquery = node_type == QueryTreeNodeType::QUERY
3941+
? node->as<QueryNode>()->isCorrelated()
3942+
: node->as<UnionNode>()->isCorrelated();
3943+
3944+
if (!allow_table_expression && !is_correlated_subquery)
39333945
evaluateScalarSubqueryIfNeeded(node, subquery_scope);
39343946

39353947
if (result_projection_names.empty())
@@ -4022,7 +4034,12 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(
40224034
* Example: CREATE TABLE test_table (id UInt64, value UInt64) ENGINE=TinyLog; SELECT plus(*) FROM test_table;
40234035
* Example: SELECT *** FROM system.one;
40244036
*/
4025-
ProjectionNames QueryAnalyzer::resolveExpressionNodeList(QueryTreeNodePtr & node_list, IdentifierResolveScope & scope, bool allow_lambda_expression, bool allow_table_expression)
4037+
ProjectionNames QueryAnalyzer::resolveExpressionNodeList(
4038+
QueryTreeNodePtr & node_list,
4039+
IdentifierResolveScope & scope,
4040+
bool allow_lambda_expression,
4041+
bool allow_table_expression
4042+
)
40264043
{
40274044
auto & node_list_typed = node_list->as<ListNode &>();
40284045
size_t node_list_size = node_list_typed.getNodes().size();

src/Interpreters/Context.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5529,10 +5529,10 @@ void Context::setGoogleProtosPath(const String & path)
55295529
shared->google_protos_path = path;
55305530
}
55315531

5532-
Context::SampleBlockCache & Context::getSampleBlockCache() const
5532+
std::pair<Context::SampleBlockCache *, std::unique_lock<std::mutex>> Context::getSampleBlockCache() const
55335533
{
55345534
assert(hasQueryContext());
5535-
return getQueryContext()->sample_block_cache;
5535+
return std::make_pair(&getQueryContext()->sample_block_cache, std::unique_lock(getQueryContext()->sample_block_cache_mutex));
55365536
}
55375537

55385538

src/Interpreters/Context.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ class ContextData
534534
protected:
535535
using SampleBlockCache = std::unordered_map<std::string, Block>;
536536
mutable SampleBlockCache sample_block_cache;
537+
mutable std::mutex sample_block_cache_mutex;
537538

538539
PartUUIDsPtr part_uuids; /// set of parts' uuids, is used for query parts deduplication
539540
PartUUIDsPtr ignored_part_uuids; /// set of parts' uuids are meant to be excluded from query processing
@@ -1405,7 +1406,7 @@ class Context: public ContextData, public std::enable_shared_from_this<Context>
14051406
String getGoogleProtosPath() const;
14061407
void setGoogleProtosPath(const String & path);
14071408

1408-
SampleBlockCache & getSampleBlockCache() const;
1409+
std::pair<Context::SampleBlockCache *, std::unique_lock<std::mutex>> getSampleBlockCache() const;
14091410

14101411
/// Query parameters for prepared statements.
14111412
bool hasQueryParameters() const;

src/Interpreters/InterpreterSelectWithUnionQuery.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,20 +283,23 @@ Block InterpreterSelectWithUnionQuery::getSampleBlock(const ASTPtr & query_ptr_,
283283
return InterpreterSelectWithUnionQuery(query_ptr_, context_, std::move(options.analyze())).getSampleBlock();
284284
}
285285

286-
auto & cache = context_->getSampleBlockCache();
287286
/// Using query string because query_ptr changes for every internal SELECT
288287
auto key = query_ptr_->formatWithSecretsOneLine();
289-
if (cache.find(key) != cache.end())
290288
{
291-
return cache[key];
289+
auto [cache, lock] = context_->getSampleBlockCache();
290+
if (cache->find(key) != cache->end())
291+
return cache->at(key);
292292
}
293293

294294
SelectQueryOptions options;
295295
if (is_subquery)
296296
options = options.subquery();
297297
if (is_create_parameterized_view)
298298
options = options.createParameterizedView();
299-
return cache[key] = InterpreterSelectWithUnionQuery(query_ptr_, context_, std::move(options.analyze())).getSampleBlock();
299+
300+
auto sample_block = InterpreterSelectWithUnionQuery(query_ptr_, context_, std::move(options.analyze())).getSampleBlock();
301+
auto [cache, lock] = context_->getSampleBlockCache();
302+
return (*cache)[key] = sample_block;
300303
}
301304

302305

src/Interpreters/JoinInfo.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,16 @@ JoinActionRef JoinActionRef::deserialize(ReadBuffer & in, const ActionsDAGRawPtr
356356
return res;
357357
}
358358

359+
JoinActionRef JoinActionRef::clone(const ActionsDAG * actions_dag_) const
360+
{
361+
return JoinActionRef{actions_dag_, column_name};
362+
}
363+
364+
JoinActionRef::JoinActionRef(const ActionsDAG * actions_dag_, const String & column_name_)
365+
: actions_dag(actions_dag_)
366+
, column_name(column_name_)
367+
{}
368+
359369
void JoinPredicate::serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const
360370
{
361371
serializePredicateOperator(op, out);
@@ -435,6 +445,40 @@ JoinCondition JoinCondition::deserialize(ReadBuffer & in, const JoinActionRef::A
435445
};
436446
}
437447

448+
JoinCondition JoinCondition::clone(const JoinExpressionActions & expression_actions) const
449+
{
450+
JoinCondition copy;
451+
452+
copy.predicates.reserve(predicates.size());
453+
for (const auto & predicate : predicates)
454+
{
455+
copy.predicates.emplace_back(
456+
predicate.left_node.clone(expression_actions.left_pre_join_actions.get()),
457+
predicate.right_node.clone(expression_actions.right_pre_join_actions.get()),
458+
predicate.op);
459+
}
460+
461+
copy.left_filter_conditions.reserve(left_filter_conditions.size());
462+
for (const auto & condition: left_filter_conditions)
463+
{
464+
copy.left_filter_conditions.emplace_back(condition.clone(expression_actions.left_pre_join_actions.get()));
465+
}
466+
467+
copy.right_filter_conditions.reserve(right_filter_conditions.size());
468+
for (const auto & condition: right_filter_conditions)
469+
{
470+
copy.right_filter_conditions.emplace_back(condition.clone(expression_actions.right_pre_join_actions.get()));
471+
}
472+
473+
copy.residual_conditions.reserve(residual_conditions.size());
474+
for (const auto & condition: residual_conditions)
475+
{
476+
copy.residual_conditions.emplace_back(condition.clone(expression_actions.post_join_actions.get()));
477+
}
478+
479+
return copy;
480+
}
481+
438482
void JoinExpression::serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const
439483
{
440484
UInt8 is_using_flag = is_using ? 1 : 0;
@@ -473,6 +517,20 @@ JoinExpression JoinExpression::deserialize(ReadBuffer & in, const JoinActionRef:
473517
return {std::move(condition), std::move(disjunctive_conditions), bool(is_using_flag)};
474518
}
475519

520+
JoinExpression JoinExpression::clone(const JoinExpressionActions & expression_copy) const
521+
{
522+
JoinExpression copy;
523+
copy.condition = condition.clone(expression_copy);
524+
525+
copy.disjunctive_conditions.reserve(disjunctive_conditions.size());
526+
for (const auto & disjunctive_condition : disjunctive_conditions)
527+
copy.disjunctive_conditions.emplace_back(disjunctive_condition.clone(expression_copy));
528+
529+
copy.is_using = is_using;
530+
531+
return copy;
532+
}
533+
476534
void JoinInfo::serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const
477535
{
478536
expression.serialize(out, dags);

src/Interpreters/JoinInfo.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@ struct JoinExpressionActions
7575
{
7676
}
7777

78+
JoinExpressionActions clone() const
79+
{
80+
return JoinExpressionActions(
81+
std::make_unique<ActionsDAG>(left_pre_join_actions->clone()),
82+
std::make_unique<ActionsDAG>(right_pre_join_actions->clone()),
83+
std::make_unique<ActionsDAG>(post_join_actions->clone())
84+
);
85+
}
86+
87+
bool hasCorrelatedExpressions() const noexcept
88+
{
89+
return left_pre_join_actions->hasCorrelatedColumns() || right_pre_join_actions->hasCorrelatedColumns() || post_join_actions->hasCorrelatedColumns();
90+
}
91+
7892
ActionsDAGPtr left_pre_join_actions;
7993
ActionsDAGPtr right_pre_join_actions;
8094
ActionsDAGPtr post_join_actions;
@@ -103,7 +117,11 @@ class JoinActionRef
103117
void serialize(WriteBuffer & out, const ActionsDAGRawPtrs & dags) const;
104118
static JoinActionRef deserialize(ReadBuffer & in, const ActionsDAGRawPtrs & dags);
105119

120+
JoinActionRef clone(const ActionsDAG * actions_dag_) const;
121+
106122
private:
123+
JoinActionRef(const ActionsDAG * actions_dag_, const String & column_name_);
124+
107125
const ActionsDAG * actions_dag = nullptr;
108126
String column_name;
109127
};
@@ -136,6 +154,8 @@ struct JoinCondition
136154

137155
void serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const;
138156
static JoinCondition deserialize(ReadBuffer & in, const JoinActionRef::ActionsDAGRawPtrs & dags);
157+
158+
JoinCondition clone(const JoinExpressionActions & expression_actions) const;
139159
};
140160

141161
struct JoinExpression
@@ -154,6 +174,8 @@ struct JoinExpression
154174

155175
void serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const;
156176
static JoinExpression deserialize(ReadBuffer & in, const JoinActionRef::ActionsDAGRawPtrs & dags);
177+
178+
JoinExpression clone(const JoinExpressionActions & expression_copy) const;
157179
};
158180

159181
struct JoinInfo
@@ -170,6 +192,11 @@ struct JoinInfo
170192
/// The locality of the join (e.g., LOCAL, GLOBAL)
171193
JoinLocality locality;
172194

195+
JoinInfo clone(const JoinExpressionActions & expression_actions) const
196+
{
197+
return JoinInfo{ expression.clone(expression_actions), kind, strictness, locality};
198+
}
199+
173200
void serialize(WriteBuffer & out, const JoinActionRef::ActionsDAGRawPtrs & dags) const;
174201
static JoinInfo deserialize(ReadBuffer & in, const JoinActionRef::ActionsDAGRawPtrs & dags);
175202
};

src/Planner/CollectColumnIdentifiers.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
#include <Planner/CollectColumnIdentifiers.h>
22

3-
#include <Analyzer/InDepthQueryTreeVisitor.h>
43
#include <Analyzer/ColumnNode.h>
4+
#include <Analyzer/FunctionNode.h>
5+
#include <Analyzer/InDepthQueryTreeVisitor.h>
6+
#include <Analyzer/QueryNode.h>
7+
#include <Analyzer/UnionNode.h>
8+
#include <Analyzer/Utils.h>
59

610
#include <Planner/PlannerContext.h>
711

@@ -35,6 +39,30 @@ class CollectTopLevelColumnIdentifiersVisitor : public ConstInDepthQueryTreeVisi
3539

3640
void visitImpl(const QueryTreeNodePtr & node)
3741
{
42+
if (node->getNodeType() == QueryTreeNodeType::FUNCTION)
43+
{
44+
auto * function_node = node->as<FunctionNode>();
45+
for (const auto & argument : function_node->getArguments().getNodes())
46+
{
47+
if (!isCorrelatedQueryOrUnionNode(argument))
48+
continue;
49+
50+
auto * query_node = argument->as<QueryNode>();
51+
auto * union_node = argument->as<UnionNode>();
52+
53+
const auto & correlated_columns = query_node != nullptr ? query_node->getCorrelatedColumns() : union_node->getCorrelatedColumns();
54+
for (const auto & column : correlated_columns)
55+
{
56+
const auto * column_identifier = planner_context->getColumnNodeIdentifierOrNull(column);
57+
if (!column_identifier)
58+
return;
59+
60+
used_identifiers.insert(*column_identifier);
61+
}
62+
}
63+
return;
64+
}
65+
3866
if (node->getNodeType() != QueryTreeNodeType::COLUMN)
3967
return;
4068

0 commit comments

Comments
 (0)