@@ -204,19 +204,14 @@ void propagate_rex_input_renumber(
204
204
work_set.pop_back ();
205
205
auto modified_node = const_cast <hdk::ir::Node*>(node);
206
206
if (auto filter = dynamic_cast <hdk::ir::Filter*>(modified_node)) {
207
- auto new_condition_expr = visitor.visit (filter->getConditionExpr ());
208
- filter->setCondition (std::move (new_condition_expr));
207
+ filter->rewriteExprs (visitor);
209
208
auto usrs_it = du_web.find (filter);
210
209
CHECK (usrs_it != du_web.end () && usrs_it->second .size () == 1 );
211
210
work_set.push_back (*usrs_it->second .begin ());
212
211
continue ;
213
212
}
214
213
if (auto project = dynamic_cast <hdk::ir::Project*>(modified_node)) {
215
- hdk::ir::ExprPtrVector new_exprs;
216
- for (size_t i = 0 ; i < project->size (); ++i) {
217
- new_exprs.push_back (visitor.visit (project->getExpr (i).get ()));
218
- }
219
- project->setExpressions (std::move (new_exprs));
214
+ project->rewriteExprs (visitor);
220
215
continue ;
221
216
}
222
217
CHECK (false );
@@ -1057,13 +1052,9 @@ void propagate_input_renumbering(
1057
1052
work_set.pop_front ();
1058
1053
CHECK (!dynamic_cast <const hdk::ir::Scan*>(walker));
1059
1054
auto node = const_cast <hdk::ir::Node*>(walker);
1060
- if (auto project = dynamic_cast <hdk::ir::Project*>(node)) {
1061
- hdk::ir::ExprPtrVector new_exprs;
1062
- new_exprs.reserve (project->size ());
1063
- for (auto & expr : project->getExprs ()) {
1064
- new_exprs.emplace_back (visitor.visit (expr.get ()));
1065
- }
1066
- project->setExpressions (std::move (new_exprs));
1055
+ if (node->is <hdk::ir::Project>() || node->is <hdk::ir::Join>() ||
1056
+ node->is <hdk::ir::Filter>()) {
1057
+ node->rewriteExprs (visitor);
1067
1058
} else if (auto aggregate = dynamic_cast <hdk::ir::Aggregate*>(node)) {
1068
1059
auto src_it = liveout_renumbering.find (node->getInput (0 ));
1069
1060
CHECK (src_it != liveout_renumbering.end ());
@@ -1074,12 +1065,6 @@ void propagate_input_renumbering(
1074
1065
new_exprs.emplace_back (visitor.visit (expr.get ()));
1075
1066
}
1076
1067
aggregate->setAggExprs (std::move (new_exprs));
1077
- } else if (auto join = dynamic_cast <hdk::ir::Join*>(node)) {
1078
- auto new_condition = visitor.visit (join->getCondition ());
1079
- join->setCondition (std::move (new_condition));
1080
- } else if (auto filter = dynamic_cast <hdk::ir::Filter*>(node)) {
1081
- auto new_condition_expr = visitor.visit (filter->getConditionExpr ());
1082
- filter->setCondition (new_condition_expr);
1083
1068
} else if (auto sort = dynamic_cast <hdk::ir::Sort*>(node)) {
1084
1069
auto src_it = liveout_renumbering.find (node->getInput (0 ));
1085
1070
CHECK (src_it != liveout_renumbering.end ());
@@ -1329,8 +1314,7 @@ void sink_projected_boolean_expr_to_join(
1329
1314
project->setExpressions (std::move (new_proj_exprs));
1330
1315
1331
1316
ConditionReplacer replacer (in_idx_to_new_subcond);
1332
- auto new_condition = replacer.visit (join->getCondition ());
1333
- join->setCondition (std::move (new_condition));
1317
+ join->rewriteExprs (replacer);
1334
1318
}
1335
1319
}
1336
1320
0 commit comments