6
6
*/
7
7
8
8
#include " Node.h"
9
- #include " ExprRewriter .h"
9
+ #include " InputRewriter .h"
10
10
11
11
namespace hdk ::ir {
12
12
13
13
namespace {
14
14
15
15
const unsigned FIRST_NODE_ID = 1 ;
16
16
17
- class RebindInputsVisitor : public ExprRewriter {
18
- public:
19
- RebindInputsVisitor (const Node* old_input, const Node* new_input)
20
- : old_input_(old_input), new_input_(new_input) {}
21
-
22
- ExprPtr visitColumnRef (const ColumnRef* col_ref) override {
23
- if (col_ref->node () == old_input_) {
24
- return makeExpr<ColumnRef>(col_ref->type (), new_input_, col_ref->index ());
25
- }
26
- return ExprRewriter::visitColumnRef (col_ref);
27
- }
28
-
29
- protected:
30
- const Node* old_input_;
31
- const Node* new_input_;
32
- };
33
-
34
- class RebindReindexInputsVisitor : public RebindInputsVisitor {
35
- public:
36
- RebindReindexInputsVisitor (
37
- const Node* old_input,
38
- const Node* new_input,
39
- const std::optional<std::unordered_map<unsigned , unsigned >>& old_to_new_index_map)
40
- : RebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
41
-
42
- ExprPtr visitColumnRef (const ColumnRef* col_ref) override {
43
- auto res = RebindInputsVisitor::visitColumnRef (col_ref);
44
- if (mapping_) {
45
- auto new_col_ref = dynamic_cast <const ColumnRef*>(res.get ());
46
- CHECK (new_col_ref);
47
- auto it = mapping_->find (new_col_ref->index ());
48
- CHECK (it != mapping_->end ());
49
- return makeExpr<ColumnRef>(new_col_ref->type (), new_col_ref->node (), it->second );
50
- }
51
- return res;
52
- }
53
-
54
- protected:
55
- const std::optional<std::unordered_map<unsigned , unsigned >>& mapping_;
56
- };
57
-
58
17
std::set<std::pair<const hdk::ir::Node*, int >> getEquivCols (const hdk::ir::Node* node,
59
18
const size_t which_col) {
60
19
std::set<std::pair<const hdk::ir::Node*, int >> work_set;
@@ -149,6 +108,34 @@ Node::Node(NodeInputs inputs)
149
108
, context_data_(nullptr )
150
109
, is_nop_(false ) {}
151
110
111
+ void Node::replaceInput (
112
+ std::shared_ptr<const Node> old_input,
113
+ std::shared_ptr<const Node> input,
114
+ std::optional<std::unordered_map<unsigned , unsigned >> old_to_new_index_map) {
115
+ InputRewriter rewriter;
116
+ if (old_to_new_index_map) {
117
+ rewriter.addNodeMapping (old_input.get (), input.get (), *old_to_new_index_map);
118
+ } else {
119
+ rewriter.addNodeMapping (old_input.get (), input.get ());
120
+ }
121
+ replaceInput (old_input, input, rewriter);
122
+ }
123
+
124
+ void Node::replaceInput (std::shared_ptr<const Node> old_input,
125
+ std::shared_ptr<const Node> input,
126
+ hdk::ir::ExprRewriter& input_redirector) {
127
+ bool replaced = false ;
128
+ for (auto & input_ptr : inputs_) {
129
+ if (input_ptr == old_input) {
130
+ input_ptr = input;
131
+ replaced = true ;
132
+ }
133
+ }
134
+ if (replaced) {
135
+ rewriteExprs (input_redirector);
136
+ }
137
+ }
138
+
152
139
void Node::resetRelAlgFirstId () noexcept {
153
140
crt_id_ = FIRST_NODE_ID;
154
141
}
@@ -160,14 +147,9 @@ void Node::print() const {
160
147
Project::Project (Project const & rhs)
161
148
: Node(rhs), exprs_(rhs.exprs_), fields_(rhs.fields_) {}
162
149
163
- void Project::replaceInput (
164
- std::shared_ptr<const Node> old_input,
165
- std::shared_ptr<const Node> input,
166
- std::optional<std::unordered_map<unsigned , unsigned >> old_to_new_index_map) {
167
- Node::replaceInput (old_input, input);
168
- RebindReindexInputsVisitor visitor (old_input.get (), input.get (), old_to_new_index_map);
150
+ void Project::rewriteExprs (hdk::ir::ExprRewriter& rewriter) {
169
151
for (size_t i = 0 ; i < exprs_.size (); ++i) {
170
- exprs_[i] = visitor .visit (exprs_[i].get ());
152
+ exprs_[i] = rewriter .visit (exprs_[i].get ());
171
153
}
172
154
}
173
155
@@ -225,34 +207,25 @@ Aggregate::Aggregate(Aggregate const& rhs)
225
207
, aggs_(rhs.aggs_)
226
208
, fields_(rhs.fields_) {}
227
209
228
- void Aggregate::replaceInput (std::shared_ptr<const Node> old_input,
229
- std::shared_ptr<const Node> input) {
230
- Node::replaceInput (old_input, input);
231
- RebindInputsVisitor visitor (old_input.get (), input.get ());
210
+ void Aggregate::rewriteExprs (hdk::ir::ExprRewriter& rewriter) {
232
211
for (size_t i = 0 ; i < aggs_.size (); ++i) {
233
- aggs_[i] = visitor .visit (aggs_[i].get ());
212
+ aggs_[i] = rewriter .visit (aggs_[i].get ());
234
213
}
235
214
}
236
215
237
216
Join::Join (Join const & rhs)
238
217
: Node(rhs), condition_(rhs.condition_), join_type_(rhs.join_type_) {}
239
218
240
- void Join::replaceInput (std::shared_ptr<const Node> old_input,
241
- std::shared_ptr<const Node> input) {
242
- Node::replaceInput (old_input, input);
219
+ void Join::rewriteExprs (hdk::ir::ExprRewriter& rewriter) {
243
220
if (condition_) {
244
- RebindInputsVisitor visitor (old_input.get (), input.get ());
245
- condition_ = visitor.visit (condition_.get ());
221
+ condition_ = rewriter.visit (condition_.get ());
246
222
}
247
223
}
248
224
249
225
Filter::Filter (Filter const & rhs) : Node(rhs), condition_(rhs.condition_) {}
250
226
251
- void Filter::replaceInput (std::shared_ptr<const Node> old_input,
252
- std::shared_ptr<const Node> input) {
253
- Node::replaceInput (old_input, input);
254
- RebindInputsVisitor visitor (old_input.get (), input.get ());
255
- condition_ = visitor.visit (condition_.get ());
227
+ void Filter::rewriteExprs (hdk::ir::ExprRewriter& rewriter) {
228
+ condition_ = rewriter.visit (condition_.get ());
256
229
}
257
230
258
231
bool Sort::hasEquivCollationOf (const Sort& that) const {
0 commit comments