Skip to content

Commit 7231ef6

Browse files
committed
tmp
1 parent 68aa500 commit 7231ef6

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

paddle/fluid/framework/ir/graph.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ class Graph {
5858
return attr;
5959
}
6060

61-
std::vector<Node*> inputs;
62-
std::vector<Node*> outputs;
63-
std::vector<std::unique_ptr<Node>> nodes;
61+
std::vector<ir::Node*> inputs;
62+
std::vector<ir::Node*> outputs;
63+
std::vector<std::unique_ptr<ir::Node>> nodes;
6464

6565
private:
6666
std::map<std::string, boost::any> attrs_;

paddle/fluid/framework/ir/node.h

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <cstdint>
18+
#include <functional>
1819
#include <map>
1920
#include <string>
2021
#include <vector>
@@ -23,13 +24,23 @@ limitations under the License. */
2324

2425
namespace paddle {
2526
namespace framework {
27+
namespace ir {
2628

2729
class Node {
2830
public:
2931
enum class Type { kNone = -1, kOperation, kVariable };
3032

31-
Node() {}
32-
virtual ~Node() {}
33+
Node(const std::string& name, Type type) : name_(name), type_(type) {}
34+
35+
virtual ~Node() {
36+
for (auto& attr : attrs_) {
37+
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
38+
attr_dels_[attr.first]();
39+
}
40+
}
41+
attr_dels_.clear();
42+
attrs_.clear();
43+
}
3344

3445
int64_t ID() const { return id_; }
3546

@@ -43,17 +54,42 @@ class Node {
4354

4455
Type NodeType() const { return type_; }
4556

46-
std::vector<Node *> inputs;
47-
std::vector<Node *> outputs;
57+
template <typename AttrType>
58+
void Set(const std::string& name, AttrType attr) {
59+
attrs_[name] = attr;
60+
}
61+
62+
template <typename AttrType>
63+
void Set(const std::string& name, AttrType* attr,
64+
std::function<void(void)> attr_del) {
65+
attrs_[name] = attr;
66+
attr_dels_[name] = attr_del;
67+
}
68+
69+
std::vector<Node*> inputs;
70+
std::vector<Node*> outputs;
4871

4972
protected:
50-
std::map<std::string, std::vector<boost::any>> attrs_;
73+
std::map<std::string, boost::any> attrs_;
74+
std::map<std::string, std::function<void(void)>> attr_dels_;
5175
int64_t id_ = 0;
5276
std::string name_;
5377
Type type_;
5478

79+
private:
5580
DISABLE_COPY_AND_ASSIGN(Node);
5681
};
5782

83+
class Variable : public Node {
84+
public:
85+
explicit Variable(const std::string& name) : Node(name, Type::kVariable) {}
86+
};
87+
88+
class Operation : public Node {
89+
public:
90+
explicit Operation(const std::string& name) : Node(name, Type::kOperation) {}
91+
};
92+
93+
} // namespace ir
5894
} // namespace framework
5995
} // namespace paddle

0 commit comments

Comments
 (0)