Skip to content

Commit 25123a3

Browse files
committed
add tests
test=develop
1 parent 8c11d3f commit 25123a3

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
5353

5454
cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
5555

56+
cc_test(node_test SRCS node_test.cc DEPS node)
5657
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
5758
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
5859
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)

paddle/fluid/framework/ir/node.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,24 @@ namespace paddle {
2727
namespace framework {
2828
namespace ir {
2929

30-
// Node should normally created by Graph::CreateXXXNode().
30+
// Node should only created by Graph::CreateXXXNode().
31+
// 1. Every Node should be part of a graph. No dangling Node exists.
32+
// 2. Node only contains members necessary for building graph structure.
33+
// It doesn't contain other unrelated members, such as device, etc.
34+
//
35+
// Sometimes, for specific usages, Node needs to have additional members,
36+
// such as device_placement, version in order to be executed. It is suggested
37+
// to use composition pattern.
38+
//
39+
// class RunnableOp {
40+
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
41+
//
42+
// int any_thing_;
43+
// }
44+
//
45+
// RunnableOp is owned by the ir::Node that composes it. In other words.
46+
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
47+
// is deleted from the graph.
3148
class Node {
3249
public:
3350
virtual ~Node() {
@@ -53,6 +70,7 @@ class Node {
5370
return op_desc_.get();
5471
}
5572

73+
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
5674
template <typename T>
5775
void WrappedBy(T* wrapper) {
5876
if (!wrapper_.empty()) {
@@ -63,11 +81,13 @@ class Node {
6381
wrapper_type_ = std::type_index(typeid(T));
6482
}
6583

84+
// Return a reference to the `wrapper`.
6685
template <typename T>
6786
T& Wrapper() {
6887
return *boost::any_cast<T*>(wrapper_);
6988
}
7089

90+
// Test if the Node is wrapped by type T.
7191
template <typename T>
7292
bool IsWrappedBy() {
7393
return std::type_index(typeid(T)) == wrapper_type_;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <string>
16+
#include "gtest/gtest.h"
17+
#include "paddle/fluid/framework/ir/graph.h"
18+
#include "paddle/fluid/framework/ir/pass.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
class RunnableOp {
25+
public:
26+
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
27+
node_->WrappedBy(this);
28+
}
29+
30+
virtual ~RunnableOp() { *alive_ = false; }
31+
32+
private:
33+
Node* node_;
34+
bool* alive_;
35+
};
36+
37+
class RunnableOp2 {
38+
public:
39+
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
40+
node_->WrappedBy(this);
41+
}
42+
43+
virtual ~RunnableOp2() { *alive_ = false; }
44+
45+
private:
46+
Node* node_;
47+
bool* alive_;
48+
};
49+
50+
TEST(NodeTest, Basic) {
51+
bool alive1 = true;
52+
bool alive2 = true;
53+
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
54+
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
55+
56+
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
57+
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
58+
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
59+
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
60+
61+
new RunnableOp(n1.get(), &alive1);
62+
new RunnableOp2(n2.get(), &alive2);
63+
64+
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
65+
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
66+
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
67+
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
68+
69+
EXPECT_TRUE(alive1);
70+
EXPECT_TRUE(alive2);
71+
72+
n1.reset(nullptr);
73+
n2.reset(nullptr);
74+
EXPECT_FALSE(alive1);
75+
EXPECT_FALSE(alive2);
76+
}
77+
78+
} // namespace ir
79+
} // namespace framework
80+
} // namespace paddle

0 commit comments

Comments
 (0)