Skip to content

Commit e0e70a4

Browse files
authored
perf: opt nested agg calls (#4022)
* perf(expr pass): cache aggregate function call in single projection * test: add two test cases for aggregate calls 1. nested aggregate call 2. duplicated aggregate call * feat(lambdafy): cache same expression Initially we only intended to cache udaf calls under the same window; this change seems to cache all expressions under the correct context. This cache mechanism may not work since it's not fully tested. * feat(expr pass): letify expression to avoid repeated agg calls in nested aggregate expression A nested agg function call, e.g. agg_fn1(col1, agg_fn2(col2)), will evaluate like a let expression: LET fn2 = agg_fn2(col2) IN agg_fn1(col1, fn2) TODO: - fix agg call cache * feat(expr pass): equivalent agg call cache TODO: - there are some cases that still need fixing * fix: set output to id node for equivalent agg calls
1 parent 3ea64b4 commit e0e70a4

File tree

15 files changed

+466
-82
lines changed

15 files changed

+466
-82
lines changed

cases/query/udaf_query.yaml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,54 @@ cases:
216216
- [3, "aa", 1, 8]
217217
- [4, "aa", 0, 21]
218218

219+
- id: 6
220+
desc: |
221+
nested aggregate function call
222+
inputs:
223+
- columns: [ "id int","ts timestamp","group1 string","val1 string" ]
224+
indexs: [ "index1:group1:ts" ]
225+
name: t1
226+
data: |
227+
1, 1612130400000, g1, 1
228+
2, 1612130401000, g1, 2
229+
3, 1612130402000, g1, 4
230+
sql: |
231+
select
232+
`id`,
233+
`val1`,
234+
first_value(val1) over w1 as ft,
235+
count_where(id, val1 = first_value(val1)) over w1 as agg1,
236+
count_where(id, val1 != first_value(val1)) over w1 as agg2,
237+
from `t1` WINDOW
238+
w1 as (partition by `group1` order by `ts` rows_range between 5s preceding and 0s preceding)
239+
expect:
240+
columns: ["id int", "val1 string", "ft string", "agg1 int64", "agg2 int64"]
241+
order: id
242+
rows:
243+
- [1, 1, 1, 1, 0]
244+
- [2, 2, 1, 1, 1]
245+
- [3, 4, 1, 1, 2]
246+
- id: 7
247+
desc: |
248+
duplicated aggregate call
249+
inputs:
250+
- columns: [ "id int","ts timestamp","group1 string","val1 string" ]
251+
indexs: [ "index1:group1:ts" ]
252+
name: t1
253+
data: |
254+
1, 1612130400000, g1, 1
255+
2, 1612130401000, g1, 2
256+
sql: |
257+
select
258+
`id`,
259+
`val1`,
260+
first_value(val1) over w1 as agg1,
261+
first_value(val1) over w1 as agg2,
262+
from `t1` WINDOW
263+
w1 as (partition by `group1` order by `ts` rows_range between 5s preceding and 0s preceding)
264+
expect:
265+
columns: ["id int", "val1 string", "agg1 string", "agg2 string"]
266+
order: id
267+
rows:
268+
- [1, 1, "1", "1"]
269+
- [2, 2, "1", "1"]

hybridse/include/node/node_enum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ enum ExprType {
149149
kExprArray,
150150
kExprArrayElement, // extract value from a array or map, with `[]` operator
151151
kExprStructCtorParens, // (expr1, expr2, ...)
152+
kExprLet, // LET [bindings] in <output expr>
152153
kExprFake, // not a real one
153154
kExprLast = kExprFake,
154155
};

hybridse/include/node/sql_node.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,21 @@
1717
#ifndef HYBRIDSE_INCLUDE_NODE_SQL_NODE_H_
1818
#define HYBRIDSE_INCLUDE_NODE_SQL_NODE_H_
1919

20+
#include <absl/status/status.h>
2021
#include <iostream>
2122
#include <map>
2223
#include <memory>
2324
#include <string>
2425
#include <unordered_map>
26+
#include <utility>
2527
#include <vector>
2628

29+
#include "absl/container/flat_hash_map.h"
2730
#include "absl/status/statusor.h"
2831
#include "absl/strings/match.h"
2932
#include "absl/strings/str_cat.h"
3033
#include "absl/strings/string_view.h"
34+
#include "base/fe_status.h"
3135
#include "boost/algorithm/string.hpp"
3236
#include "boost/algorithm/string/predicate.hpp"
3337
#include "boost/filesystem/operations.hpp"
@@ -618,6 +622,7 @@ class ArrayElementExpr : public ExprNode {
618622
Status InferAttr(ExprAnalysisContext *ctx) override;
619623
};
620624

625+
621626
class FnNode : public SqlNode {
622627
public:
623628
FnNode() : SqlNode(kFn, 0, 0), indent(0) {}
@@ -1869,6 +1874,59 @@ class EscapedExpr : public ExprNode {
18691874
}
18701875
};
18711876

1877+
class LetExpr : public ExprNode {
1878+
public:
1879+
class LetCtxEntry {
1880+
public:
1881+
LetCtxEntry(ExprIdNode *id_node, ExprNode *expr, const FrameNode *frame)
1882+
: id_node(id_node), expr(expr), frame(frame) {}
1883+
1884+
ExprIdNode *id_node;
1885+
ExprNode *expr; // referred udaf call
1886+
const FrameNode *frame;
1887+
};
1888+
1889+
class LetContext {
1890+
public:
1891+
base::Status Append(ExprIdNode *k, ExprNode *v, const FrameNode *frame) {
1892+
auto it = cache.find(k);
1893+
if (it == cache.end()) {
1894+
bindings.emplace_back(k, v, frame);
1895+
cache.emplace(k, v);
1896+
} else {
1897+
CHECK_TRUE(v == it->second, common::kPlanError,
1898+
"let context: try mapping id node to two different resolved nodes");
1899+
}
1900+
1901+
return base::Status::OK();
1902+
}
1903+
1904+
bool empty() const noexcept { return bindings.empty(); }
1905+
1906+
// necessary let bindings, not all cached entry will appear in bindings
1907+
std::vector<LetCtxEntry> bindings;
1908+
absl::flat_hash_map<ExprIdNode *, ExprNode *> cache;
1909+
};
1910+
1911+
LetExpr(ExprNode *expr, const LetContext &ctx) : ExprNode(kExprLet), ctx_(ctx), expr_(expr) {
1912+
AddChild(expr);
1913+
}
1914+
~LetExpr() override {}
1915+
1916+
1917+
const LetContext& ctx() const {return ctx_; }
1918+
ExprNode *expr() const { return expr_; }
1919+
1920+
void Print(std::ostream &output, const std::string &org_tab) const override;
1921+
const std::string GetExprString() const override;
1922+
LetExpr *ShadowCopy(NodeManager *nm) const override;
1923+
Status InferAttr(ExprAnalysisContext *ctx) override;
1924+
1925+
private:
1926+
LetContext ctx_;
1927+
ExprNode* expr_;
1928+
};
1929+
18721930
class ResTarget : public SqlNode {
18731931
public:
18741932
ResTarget() : SqlNode(kResTarget, 0, 0), name_(""), val_(nullptr) {}

hybridse/include/vm/physical_op.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ class ColumnProjects : public FnComponent {
330330
primary_frame_ = frame;
331331
}
332332

333+
// Returns a copy of the frame list held in frames_.
std::vector<const node::FrameNode *> frames() const {
    return frames_;
}
334+
333335
const node::FrameNode *GetPrimaryFrame() const { return primary_frame_; }
334336

335337
base::Status ReplaceExpr(const passes::ExprReplacer &replacer,

hybridse/src/codegen/fn_let_ir_builder.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,31 @@ Status RowFnLetIRBuilder::Build(
109109
std::map<std::string, AggregateIRBuilder> window_agg_builder;
110110
uint32_t agg_builder_id = 0;
111111

112-
auto expr_list = compile_func->body();
112+
auto fn_body = compile_func->body();
113+
auto expr_list = fn_body;
114+
auto maybe_let_expr = fn_body->GetAsOrNull<node::LetExpr>();
115+
bool is_let_fn_body = maybe_let_expr != nullptr;
116+
if (is_let_fn_body) {
117+
expr_list = maybe_let_expr->expr();
118+
}
113119
CHECK_TRUE(project_frames.size() == expr_list->GetChildNum(), kCodegenError,
114120
"Frame num should match expr num");
115121

122+
if (is_let_fn_body) {
123+
for (auto& entry : maybe_let_expr->ctx().bindings) {
124+
auto key = entry.id_node;
125+
auto exp = entry.expr;
126+
auto frame = entry.frame;
127+
CHECK_STATUS(BindProjectFrame(&expr_ir_builder, frame, compile_func, ctx_->GetCurrentBlock(), sv));
128+
NativeValue exp_out;
129+
CHECK_STATUS(expr_ir_builder.Build(exp, &exp_out));
130+
131+
base::Status s;
132+
variable_ir_builder.StoreValue(key->GetExprString(), exp_out, s);
133+
CHECK_STATUS(s);
134+
}
135+
}
136+
116137
for (size_t i = 0; i < expr_list->GetChildNum(); ++i) {
117138
const ::hybridse::node::ExprNode* expr = expr_list->GetChild(i);
118139
auto frame = project_frames[i];

hybridse/src/node/expr_node.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
#include "node/expr_node.h"
1818

1919
#include "absl/strings/str_cat.h"
20+
#include "absl/strings/str_join.h"
2021
#include "absl/strings/substitute.h"
2122
#include "codec/fe_row_codec.h"
23+
#include "codegen/ir_base_builder.h"
2224
#include "node/node_manager.h"
2325
#include "node/sql_node.h"
2426
#include "passes/expression/expr_pass.h"
2527
#include "passes/resolve_fn_and_attrs.h"
2628
#include "vm/schemas_context.h"
27-
#include "codegen/ir_base_builder.h"
2829

2930
using ::hybridse::common::kTypeError;
3031

@@ -1231,6 +1232,30 @@ Status ArrayElementExpr::InferAttr(ExprAnalysisContext* ctx) {
12311232
ExprNode *ArrayElementExpr::array() const { return GetChild(0); }
12321233
ExprNode *ArrayElementExpr::position() const { return GetChild(1); }
12331234

1235+
1236+
// Print hook for LET expressions. The body is empty: nothing is written to
// `output` and `org_tab` is unused.
void LetExpr::Print(std::ostream &output, const std::string &org_tab) const {
}
1237+
const std::string LetExpr::GetExprString() const {
1238+
return absl::Substitute("LET $0 IN $1",
1239+
absl::StrJoin(ctx_.bindings, ",",
1240+
[](std::string* out, const decltype(LetContext::bindings)::value_type& kv) {
1241+
absl::StrAppend(out, kv.id_node->GetExprString(), "=",
1242+
kv.expr->GetExprString());
1243+
}),
1244+
expr_->GetExprString());
1245+
}
1246+
// Asks the node manager for a new LetExpr built from this node's output
// expression pointer and binding context; the output expression itself is not
// cloned here.
LetExpr* LetExpr::ShadowCopy(NodeManager* nm) const {
    LetExpr* duplicate = nm->MakeNode<LetExpr>(expr_, ctx_);
    return duplicate;
}
1247+
1248+
// Validates every binding and then copies the output expression's type and
// nullability onto this node.
//
// A binding is rejected with common::kTypeError when its id node has no output
// type yet, or when the id node's output type differs from the bound
// expression's output type.
Status LetExpr::InferAttr(ExprAnalysisContext* ctx) {
    for (auto& binding : ctx_.bindings) {
        auto* id = binding.id_node;
        auto* bound = binding.expr;
        CHECK_TRUE(id->GetOutputType() != nullptr, common::kTypeError, "expr id node not resolved");
        // NOTE(review): this compares the values returned by GetOutputType() directly;
        // if those are pointers, it is an identity check rather than a structural one — confirm intended.
        CHECK_TRUE(id->GetOutputType() == bound->GetOutputType(), common::kTypeError,
                   "expr id node return type does not match binding expr");
    }

    SetOutputType(expr_->GetOutputType());
    SetNullable(expr_->nullable());
    return {};
}
1258+
12341259
StructCtorWithParens* StructCtorWithParens::ShadowCopy(NodeManager* nm) const {
12351260
return nm->MakeNode<StructCtorWithParens>(fields());
12361261
}

hybridse/src/node/sql_node.cc

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -117,38 +117,39 @@ static const absl::flat_hash_map<DataType, absl::string_view>& GetDataTypeNamesM
117117
}
118118

119119
static absl::flat_hash_map<ExprType, absl::string_view> CreateExprTypeNamesMap() {
120-
absl::flat_hash_map<ExprType, absl::string_view> map = {
121-
{kExprPrimary, "primary"},
122-
{kExprParameter, "parameter"},
123-
{kExprId, "id"},
124-
{kExprBinary, "binary"},
125-
{kExprUnary, "unary"},
126-
{kExprCall, "function"},
127-
{kExprCase, "case"},
128-
{kExprWhen, "when"},
129-
{kExprBetween, "between"},
130-
{kExprColumnRef, "column ref"},
131-
{kExprColumnId, "column id"},
132-
{kExprCast, "cast"},
133-
{kExprAll, "all"},
134-
{kExprStruct, "struct"},
135-
{kExprQuery, "query"},
136-
{kExprOrder, "order"},
137-
{kExprGetField, "get field"},
138-
{kExprCond, "cond"},
139-
{kExprUnknow, "unknow"},
140-
{kExprIn, "in"},
141-
{kExprList, "expr_list"},
142-
{kExprForIn, "for_in"},
143-
{kExprRange, "range"},
144-
{kExprOrderExpression, "order"},
145-
{kExprEscaped, "escape"},
146-
{kExprArray, "array"},
147-
{kExprArrayElement, "array element"},
148-
{kExprStructCtorParens, "struct with parens"},
149-
};
150-
for (auto kind = 0; kind < ExprType::kExprLast; ++kind) {
151-
DCHECK(map.find(static_cast<ExprType>(kind)) != map.end());
120+
absl::flat_hash_map<ExprType, absl::string_view> map = {
121+
{kExprPrimary, "primary"},
122+
{kExprParameter, "parameter"},
123+
{kExprId, "id"},
124+
{kExprBinary, "binary"},
125+
{kExprUnary, "unary"},
126+
{kExprCall, "function"},
127+
{kExprCase, "case"},
128+
{kExprWhen, "when"},
129+
{kExprBetween, "between"},
130+
{kExprColumnRef, "column ref"},
131+
{kExprColumnId, "column id"},
132+
{kExprCast, "cast"},
133+
{kExprAll, "all"},
134+
{kExprStruct, "struct"},
135+
{kExprQuery, "query"},
136+
{kExprOrder, "order"},
137+
{kExprGetField, "get field"},
138+
{kExprCond, "cond"},
139+
{kExprUnknow, "unknow"},
140+
{kExprIn, "in"},
141+
{kExprList, "expr_list"},
142+
{kExprForIn, "for_in"},
143+
{kExprRange, "range"},
144+
{kExprOrderExpression, "order"},
145+
{kExprEscaped, "escape"},
146+
{kExprArray, "array"},
147+
{kExprArrayElement, "array element"},
148+
{kExprStructCtorParens, "struct with parens"},
149+
{kExprLet, "let"},
150+
};
151+
for (auto kind = 0; kind < ExprType::kExprLast; ++kind) {
152+
DCHECK(map.find(static_cast<ExprType>(kind)) != map.end());
152153
}
153154
return map;
154155
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/**
2+
* Copyright (c) 2025 Ace <[email protected]>
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "passes/expression/cache_expressions.h"
18+
19+
namespace hybridse {
20+
namespace passes {
21+
22+
static std::string CallExprKey(const node::CallExprNode* call);
23+
24+
// Builds the cache key for a call expression: the function definition's name
// followed by the expression strings of all children, comma-separated and
// wrapped in parentheses, e.g. "fn(a, b)".
std::string CallExprKey(const node::CallExprNode* call) {
    std::string key = call->GetFnDef()->GetName();
    key.append("(");

    const char* delimiter = "";
    for (const auto& arg : call->children_) {
        key.append(delimiter);
        key.append(arg->GetExprString());
        delimiter = ", ";
    }

    key.append(")");
    return key;
}
37+
38+
// Walks `expr` bottom-up and makes equal function calls share a single node.
//
// Children are processed first; a child is swapped in place whenever the
// recursive call hands back a different non-null node. Afterwards, if `expr` is
// a call expression, its key (see CallExprKey) is looked up in expr_cache_:
//   - hit:  `*out` is set to the previously cached node;
//   - miss: `expr` is stored under that key and `*out` stays `expr`.
// For every other expression kind `*out` is `expr` itself.
//
// @param ctx  analysis context, only forwarded to the recursive calls
// @param expr expression to process; a null pointer is passed through unchanged
// @param out  receives the resulting expression
// @return the first failing status from a recursive call, otherwise a default (OK) status
base::Status CacheExpressions::Apply(node::ExprAnalysisContext* ctx, node::ExprNode* expr, node::ExprNode** out) {
    *out = expr;
    if (expr == nullptr) {
        // The loop below tolerates null results, so tolerate null inputs as well
        // instead of dereferencing them.
        return {};
    }

    // size_t index: avoids the signed/unsigned comparison against GetChildNum().
    for (size_t i = 0; i < expr->GetChildNum(); ++i) {
        node::ExprNode* child = expr->GetChild(i);
        node::ExprNode* replacement = nullptr;
        CHECK_STATUS(Apply(ctx, child, &replacement));
        if (replacement != nullptr && replacement != child) {
            expr->SetChild(i, replacement);
        }
    }

    if (expr->GetExprType() != node::kExprCall) {
        return {};
    }

    auto call = expr->GetAsOrNull<node::CallExprNode>();
    if (call == nullptr) {
        return {};
    }

    auto key = CallExprKey(call);
    auto it = expr_cache_.find(key);
    if (it != expr_cache_.end()) {
        *out = it->second;
    } else {
        expr_cache_.emplace(key, call);
    }

    return {};
}
68+
69+
} // namespace passes
70+
} // namespace hybridse

0 commit comments

Comments
 (0)