@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#pragma once
16
16
17
17
#include < cstdint>
18
+ #include < functional>
18
19
#include < map>
19
20
#include < string>
20
21
#include < vector>
@@ -23,13 +24,23 @@ limitations under the License. */
23
24
24
25
namespace paddle {
25
26
namespace framework {
27
+ namespace ir {
26
28
27
29
class Node {
28
30
public:
29
31
enum class Type { kNone = -1 , kOperation , kVariable };
30
32
31
- Node () {}
32
- virtual ~Node () {}
33
+ Node (const std::string& name, Type type) : name_(name), type_(type) {}
34
+
35
+ virtual ~Node () {
36
+ for (auto & attr : attrs_) {
37
+ if (attr_dels_.find (attr.first ) != attr_dels_.end ()) {
38
+ attr_dels_[attr.first ]();
39
+ }
40
+ }
41
+ attr_dels_.clear ();
42
+ attrs_.clear ();
43
+ }
33
44
34
45
int64_t ID () const { return id_; }
35
46
@@ -43,17 +54,42 @@ class Node {
43
54
44
55
Type NodeType () const { return type_; }
45
56
46
- std::vector<Node *> inputs;
47
- std::vector<Node *> outputs;
57
+ template <typename AttrType>
58
+ void Set (const std::string& name, AttrType attr) {
59
+ attrs_[name] = attr;
60
+ }
61
+
62
+ template <typename AttrType>
63
+ void Set (const std::string& name, AttrType* attr,
64
+ std::function<void (void )> attr_del) {
65
+ attrs_[name] = attr;
66
+ attr_dels_[name] = attr_del;
67
+ }
68
+
69
+ std::vector<Node*> inputs;
70
+ std::vector<Node*> outputs;
48
71
49
72
protected:
50
- std::map<std::string, std::vector<boost::any>> attrs_;
73
+ std::map<std::string, boost::any> attrs_;
74
+ std::map<std::string, std::function<void (void )>> attr_dels_;
51
75
int64_t id_ = 0 ;
52
76
std::string name_;
53
77
Type type_;
54
78
79
+ private:
55
80
DISABLE_COPY_AND_ASSIGN (Node);
56
81
};
57
82
83
+ class Variable : public Node {
84
+ public:
85
+ explicit Variable (const std::string& name) : Node(name, Type::kVariable ) {}
86
+ };
87
+
88
+ class Operation : public Node {
89
+ public:
90
+ explicit Operation (const std::string& name) : Node(name, Type::kOperation ) {}
91
+ };
92
+
93
+ } // namespace ir
58
94
} // namespace framework
59
95
} // namespace paddle
0 commit comments