Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit f4a3dea

Browse files
committed
Refactor Node::replaceInput.
Signed-off-by: ienkovich <[email protected]>
1 parent b1e88e6 commit f4a3dea

File tree

4 files changed

+120
-104
lines changed

4 files changed

+120
-104
lines changed

omniscidb/IR/InputRewriter.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (C) 2023 Intel Corporation
3+
*
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
#pragma once
8+
9+
#include "IR/ExprRewriter.h"
10+
11+
#include <boost/functional/hash.hpp>
12+
13+
#include <unordered_map>
14+
15+
namespace hdk::ir {
16+
17+
class InputRewriter final : public ExprRewriter {
18+
public:
19+
InputRewriter() = default;
20+
21+
InputRewriter(const Node* old_input, const Node* new_input) {
22+
addNodeMapping(old_input, new_input);
23+
}
24+
25+
InputRewriter(const Node* old_input,
26+
const Node* new_input,
27+
const std::unordered_map<unsigned, unsigned>& old_to_new_index_map) {
28+
addNodeMapping(old_input, new_input, old_to_new_index_map);
29+
}
30+
31+
void addNodeMapping(const Node* old_input, const Node* new_input) {
32+
node_map_[old_input] = new_input;
33+
}
34+
35+
void addNodeMapping(
36+
const Node* old_input,
37+
const Node* new_input,
38+
const std::unordered_map<unsigned, unsigned>& old_to_new_index_map) {
39+
node_map_[old_input] = new_input;
40+
for (auto& pr : old_to_new_index_map) {
41+
index_map_[std::make_pair(old_input, pr.first)] = pr.second;
42+
}
43+
}
44+
45+
ExprPtr visitColumnRef(const ColumnRef* col_ref) override {
46+
auto node_it = node_map_.find(col_ref->node());
47+
if (node_it != node_map_.end()) {
48+
unsigned index = col_ref->index();
49+
auto idx_it = index_map_.find(std::make_pair(col_ref->node(), index));
50+
if (idx_it != index_map_.end()) {
51+
index = idx_it->second;
52+
}
53+
return makeExpr<ColumnRef>(col_ref->type(), node_it->second, index);
54+
}
55+
return ExprRewriter::visitColumnRef(col_ref);
56+
}
57+
58+
protected:
59+
std::unordered_map<const Node*, const Node*> node_map_;
60+
std::unordered_map<std::pair<const Node*, unsigned>,
61+
unsigned,
62+
boost::hash<std::pair<const Node*, unsigned>>>
63+
index_map_;
64+
};
65+
66+
} // namespace hdk::ir

omniscidb/IR/Node.cpp

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,14 @@
66
*/
77

88
#include "Node.h"
9-
#include "ExprRewriter.h"
9+
#include "InputRewriter.h"
1010

1111
namespace hdk::ir {
1212

1313
namespace {
1414

1515
const unsigned FIRST_NODE_ID = 1;
1616

17-
class RebindInputsVisitor : public ExprRewriter {
18-
public:
19-
RebindInputsVisitor(const Node* old_input, const Node* new_input)
20-
: old_input_(old_input), new_input_(new_input) {}
21-
22-
ExprPtr visitColumnRef(const ColumnRef* col_ref) override {
23-
if (col_ref->node() == old_input_) {
24-
return makeExpr<ColumnRef>(col_ref->type(), new_input_, col_ref->index());
25-
}
26-
return ExprRewriter::visitColumnRef(col_ref);
27-
}
28-
29-
protected:
30-
const Node* old_input_;
31-
const Node* new_input_;
32-
};
33-
34-
class RebindReindexInputsVisitor : public RebindInputsVisitor {
35-
public:
36-
RebindReindexInputsVisitor(
37-
const Node* old_input,
38-
const Node* new_input,
39-
const std::optional<std::unordered_map<unsigned, unsigned>>& old_to_new_index_map)
40-
: RebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
41-
42-
ExprPtr visitColumnRef(const ColumnRef* col_ref) override {
43-
auto res = RebindInputsVisitor::visitColumnRef(col_ref);
44-
if (mapping_) {
45-
auto new_col_ref = dynamic_cast<const ColumnRef*>(res.get());
46-
CHECK(new_col_ref);
47-
auto it = mapping_->find(new_col_ref->index());
48-
CHECK(it != mapping_->end());
49-
return makeExpr<ColumnRef>(new_col_ref->type(), new_col_ref->node(), it->second);
50-
}
51-
return res;
52-
}
53-
54-
protected:
55-
const std::optional<std::unordered_map<unsigned, unsigned>>& mapping_;
56-
};
57-
5817
std::set<std::pair<const hdk::ir::Node*, int>> getEquivCols(const hdk::ir::Node* node,
5918
const size_t which_col) {
6019
std::set<std::pair<const hdk::ir::Node*, int>> work_set;
@@ -149,6 +108,34 @@ Node::Node(NodeInputs inputs)
149108
, context_data_(nullptr)
150109
, is_nop_(false) {}
151110

111+
void Node::replaceInput(
112+
std::shared_ptr<const Node> old_input,
113+
std::shared_ptr<const Node> input,
114+
std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
115+
InputRewriter rewriter;
116+
if (old_to_new_index_map) {
117+
rewriter.addNodeMapping(old_input.get(), input.get(), *old_to_new_index_map);
118+
} else {
119+
rewriter.addNodeMapping(old_input.get(), input.get());
120+
}
121+
replaceInput(old_input, input, rewriter);
122+
}
123+
124+
void Node::replaceInput(std::shared_ptr<const Node> old_input,
125+
std::shared_ptr<const Node> input,
126+
hdk::ir::ExprRewriter& input_redirector) {
127+
bool replaced = false;
128+
for (auto& input_ptr : inputs_) {
129+
if (input_ptr == old_input) {
130+
input_ptr = input;
131+
replaced = true;
132+
}
133+
}
134+
if (replaced) {
135+
rewriteExprs(input_redirector);
136+
}
137+
}
138+
152139
void Node::resetRelAlgFirstId() noexcept {
153140
crt_id_ = FIRST_NODE_ID;
154141
}
@@ -160,14 +147,9 @@ void Node::print() const {
160147
Project::Project(Project const& rhs)
161148
: Node(rhs), exprs_(rhs.exprs_), fields_(rhs.fields_) {}
162149

163-
void Project::replaceInput(
164-
std::shared_ptr<const Node> old_input,
165-
std::shared_ptr<const Node> input,
166-
std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
167-
Node::replaceInput(old_input, input);
168-
RebindReindexInputsVisitor visitor(old_input.get(), input.get(), old_to_new_index_map);
150+
void Project::rewriteExprs(hdk::ir::ExprRewriter& rewriter) {
169151
for (size_t i = 0; i < exprs_.size(); ++i) {
170-
exprs_[i] = visitor.visit(exprs_[i].get());
152+
exprs_[i] = rewriter.visit(exprs_[i].get());
171153
}
172154
}
173155

@@ -225,34 +207,25 @@ Aggregate::Aggregate(Aggregate const& rhs)
225207
, aggs_(rhs.aggs_)
226208
, fields_(rhs.fields_) {}
227209

228-
void Aggregate::replaceInput(std::shared_ptr<const Node> old_input,
229-
std::shared_ptr<const Node> input) {
230-
Node::replaceInput(old_input, input);
231-
RebindInputsVisitor visitor(old_input.get(), input.get());
210+
void Aggregate::rewriteExprs(hdk::ir::ExprRewriter& rewriter) {
232211
for (size_t i = 0; i < aggs_.size(); ++i) {
233-
aggs_[i] = visitor.visit(aggs_[i].get());
212+
aggs_[i] = rewriter.visit(aggs_[i].get());
234213
}
235214
}
236215

237216
Join::Join(Join const& rhs)
238217
: Node(rhs), condition_(rhs.condition_), join_type_(rhs.join_type_) {}
239218

240-
void Join::replaceInput(std::shared_ptr<const Node> old_input,
241-
std::shared_ptr<const Node> input) {
242-
Node::replaceInput(old_input, input);
219+
void Join::rewriteExprs(hdk::ir::ExprRewriter& rewriter) {
243220
if (condition_) {
244-
RebindInputsVisitor visitor(old_input.get(), input.get());
245-
condition_ = visitor.visit(condition_.get());
221+
condition_ = rewriter.visit(condition_.get());
246222
}
247223
}
248224

249225
Filter::Filter(Filter const& rhs) : Node(rhs), condition_(rhs.condition_) {}
250226

251-
void Filter::replaceInput(std::shared_ptr<const Node> old_input,
252-
std::shared_ptr<const Node> input) {
253-
Node::replaceInput(old_input, input);
254-
RebindInputsVisitor visitor(old_input.get(), input.get());
255-
condition_ = visitor.visit(condition_.get());
227+
void Filter::rewriteExprs(hdk::ir::ExprRewriter& rewriter) {
228+
condition_ = rewriter.visit(condition_.get());
256229
}
257230

258231
bool Sort::hasEquivCollationOf(const Sort& that) const {

omniscidb/IR/Node.h

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include "Expr.h"
11+
#include "ExprRewriter.h"
1112

1213
#include "QueryEngine/TargetMetaInfo.h"
1314
#include "SchemaMgr/TableInfo.h"
@@ -131,15 +132,15 @@ class Node {
131132
return false;
132133
}
133134

134-
virtual void replaceInput(std::shared_ptr<const Node> old_input,
135-
std::shared_ptr<const Node> input) {
136-
for (auto& input_ptr : inputs_) {
137-
if (input_ptr == old_input) {
138-
input_ptr = input;
139-
break;
140-
}
141-
}
142-
}
135+
virtual void rewriteExprs(hdk::ir::ExprRewriter& rewriter) {}
136+
137+
void replaceInput(std::shared_ptr<const Node> old_input,
138+
std::shared_ptr<const Node> input,
139+
std::optional<std::unordered_map<unsigned, unsigned>>
140+
old_to_new_index_map = std::nullopt);
141+
void replaceInput(std::shared_ptr<const Node> old_input,
142+
std::shared_ptr<const Node> input,
143+
hdk::ir::ExprRewriter& input_redirector);
143144

144145
// to keep an assigned DAG node id for data recycler
145146
void setRelNodeDagId(const size_t id) const { dag_node_id_ = id; }
@@ -323,15 +324,7 @@ class Project : public Node {
323324
return fields_[i];
324325
}
325326

326-
void replaceInput(std::shared_ptr<const Node> old_input,
327-
std::shared_ptr<const Node> input) override {
328-
replaceInput(old_input, input, std::nullopt);
329-
}
330-
331-
void replaceInput(
332-
std::shared_ptr<const Node> old_input,
333-
std::shared_ptr<const Node> input,
334-
std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map);
327+
void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override;
335328

336329
void appendInput(std::string new_field_name, ExprPtr expr);
337330

@@ -403,8 +396,7 @@ class Aggregate : public Node {
403396

404397
void setAggExprs(ExprPtrVector new_aggs) { aggs_ = std::move(new_aggs); }
405398

406-
void replaceInput(std::shared_ptr<const Node> old_input,
407-
std::shared_ptr<const Node> input) override;
399+
void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override;
408400

409401
std::string toString() const override {
410402
return cat(::typeName(this),
@@ -465,8 +457,7 @@ class Join : public Node {
465457

466458
void setCondition(ExprPtr new_condition) { condition_ = std::move(new_condition); }
467459

468-
void replaceInput(std::shared_ptr<const Node> old_input,
469-
std::shared_ptr<const Node> input) override;
460+
void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override;
470461

471462
std::string toString() const override {
472463
return cat(::typeName(this),
@@ -583,10 +574,7 @@ class TranslatedJoin : public Node {
583574
std::string getOpTypeInfo() const { return op_typeinfo_; }
584575
size_t size() const override { return 0; }
585576
JoinType getJoinType() const { return join_type_; }
586-
void replaceInput(std::shared_ptr<const Node> old_input,
587-
std::shared_ptr<const Node> input) override {
588-
CHECK(false);
589-
}
577+
void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override { CHECK(false); }
590578
std::shared_ptr<Node> deepCopy() const override {
591579
CHECK(false);
592580
return nullptr;
@@ -640,8 +628,7 @@ class Filter : public Node {
640628
return inputs_[0]->size();
641629
}
642630
643-
void replaceInput(std::shared_ptr<const Node> old_input,
644-
std::shared_ptr<const Node> input) override;
631+
void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override;
645632
646633
std::string toString() const override {
647634
return cat(::typeName(this),
@@ -880,9 +867,6 @@ class QueryDag {
880867
return subqueries_;
881868
}
882869
883-
/**
884-
* Gets all registered subqueries. Only the root DAG can contain subqueries.
885-
*/
886870
void resetQueryExecutionState();
887871
888872
time_t now() const { return now_; }

omniscidb/QueryEngine/RelAlgOptimizer.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,8 @@ void redirect_inputs_of(
318318
return;
319319
}
320320
if (auto project = std::dynamic_pointer_cast<hdk::ir::Project>(node)) {
321-
project->hdk::ir::Node::replaceInput(src_project, src_project->getAndOwnInput(0));
322321
ProjectInputRedirector visitor(projects);
323-
hdk::ir::ExprPtrVector new_exprs;
324-
for (auto& expr : project->getExprs()) {
325-
new_exprs.push_back(visitor.visit(expr.get()));
326-
}
327-
project->setExpressions(std::move(new_exprs));
322+
project->replaceInput(src_project, src_project->getAndOwnInput(0), visitor);
328323
return;
329324
}
330325
if (auto filter = std::dynamic_pointer_cast<hdk::ir::Filter>(node)) {
@@ -334,10 +329,8 @@ void redirect_inputs_of(
334329
if (is_permutating_proj) {
335330
propagate_rex_input_renumber(filter.get(), du_web);
336331
}
337-
filter->hdk::ir::Node::replaceInput(src_project, src_project->getAndOwnInput(0));
338332
ProjectInputRedirector visitor(projects);
339-
auto new_condition_expr = visitor.visit(filter->getConditionExpr());
340-
filter->setCondition(new_condition_expr);
333+
filter->replaceInput(src_project, src_project->getAndOwnInput(0), visitor);
341334
} else {
342335
filter->replaceInput(src_project, src_project->getAndOwnInput(0));
343336
}

0 commit comments

Comments
 (0)