Skip to content

Commit 4d96a71

Browse files
authored
Merge pull request ClickHouse#78565 from bigo-sg/opt_count_if
Trivial optimization: do not rewrite count(if()) to countIf if CAST is required
2 parents 804885e + 097799b commit 4d96a71

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitorWith
7272

7373
QueryTreeNodes new_arguments{2};
7474

75-
/// We need to preserve the output type from if()
76-
if (if_arguments_nodes[1]->getResultType()->getName() != if_node->getResultType()->getName())
75+
/// We need to preserve the output type from if(). Notice that the return type of count() is the same either way
76+
if (if_arguments_nodes[1]->getResultType()->getName() != if_node->getResultType()->getName() && lower_name != "count")
7777
new_arguments[0] = createCastFunction(std::move(if_arguments_nodes[1]), if_node->getResultType(), getContext());
7878
else
7979
new_arguments[0] = std::move(if_arguments_nodes[1]);
@@ -99,7 +99,8 @@ class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitorWith
9999

100100
QueryTreeNodes new_arguments{2};
101101

102-
if (if_arguments_nodes[2]->getResultType()->getName() != if_node->getResultType()->getName())
102+
/// We need to preserve the output type from if(). Notice that the return type of count() is the same either way
103+
if (if_arguments_nodes[2]->getResultType()->getName() != if_node->getResultType()->getName() && lower_name != "count")
103104
new_arguments[0] = createCastFunction(std::move(if_arguments_nodes[2]), if_node->getResultType(), getContext());
104105
else
105106
new_arguments[0] = std::move(if_arguments_nodes[2]);
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<test>
2+
<query>select count(if(rand() % 2 = 0, 1, null)) from numbers(100000000) settings optimize_rewrite_aggregate_function_with_if = true</query>
3+
<query>select count(if(rand() % 2 = 0, null, 1)) from numbers(100000000) settings optimize_rewrite_aggregate_function_with_if = true</query>
4+
5+
<query>select count(if(rand() % 2 = 0, toNullable(1), null)) from numbers(100000000) settings optimize_rewrite_aggregate_function_with_if = true</query>
6+
<query>select count(if(rand() % 2 = 0, null, toNullable(1))) from numbers(100000000) settings optimize_rewrite_aggregate_function_with_if = true</query>
7+
</test>

0 commit comments

Comments
 (0)