Skip to content

Commit 62e2aa1

Browse files
committed
add a graph_test
1 parent a323b26 commit 62e2aa1

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
cc_library(graph SRCS graph.cc)
1+
cc_library(graph SRCS graph.cc node)
22
cc_library(node SRCS node.cc)
3-
cc_library(pass SRCS pass.cc)
3+
cc_library(pass SRCS pass.cc graph node)
4+
5+
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/graph.h"
16+
#include "gtest/gtest.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/framework/operator.h"
19+
#include "paddle/fluid/framework/program_desc.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
24+
class NOP : public OperatorBase {
25+
public:
26+
NOP(const std::string &type, const VariableNameMap &inputs,
27+
const VariableNameMap &outputs, const AttributeMap &attrs)
28+
: OperatorBase(type, inputs, outputs, attrs) {}
29+
30+
private:
31+
void RunImpl(const Scope &scope,
32+
const platform::Place &place) const override {}
33+
};
34+
35+
class SumOpMaker : public OpProtoAndCheckerMaker {
36+
public:
37+
void Make() {
38+
AddInput("X", "").AsDuplicable();
39+
AddOutput("Out", "");
40+
AddComment("");
41+
}
42+
};
43+
44+
class SumOpVarTypeInference : public VarTypeInference {
45+
public:
46+
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
47+
auto &inputs = op_desc.Input("X");
48+
auto default_var_type = proto::VarType::SELECTED_ROWS;
49+
50+
bool any_input_is_lod_tensor = std::any_of(
51+
inputs.begin(), inputs.end(), [block](const std::string &name) {
52+
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
53+
});
54+
if (any_input_is_lod_tensor) {
55+
default_var_type = proto::VarType::LOD_TENSOR;
56+
}
57+
58+
auto out_var_name = op_desc.Output("Out").front();
59+
block->Var(out_var_name)->SetType(default_var_type);
60+
}
61+
};
62+
} // namespace framework
63+
} // namespace paddle
64+
65+
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
66+
paddle::framework::SumOpVarTypeInference);
67+
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
68+
paddle::framework::SumOpMaker);
69+
70+
namespace paddle {
71+
namespace framework {
72+
73+
TEST(GraphTest, Basic) {
74+
ProgramDesc prog;
75+
auto *op = prog.MutableBlock(0)->AppendOp();
76+
op->SetType("sum");
77+
op->SetInput("X", {"test_a", "test_b", "test_c"});
78+
op->SetOutput("Out", {"test_out"});
79+
80+
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
81+
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
82+
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
83+
prog.MutableBlock(0)->Var("test_out");
84+
85+
op->InferVarType(prog.MutableBlock(0));
86+
87+
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
88+
prog.MutableBlock(0)->Var("test_out")->GetType());
89+
90+
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
91+
op->InferVarType(prog.MutableBlock(0));
92+
ASSERT_EQ(proto::VarType::LOD_TENSOR,
93+
prog.MutableBlock(0)->Var("test_out")->GetType());
94+
95+
std::unique_ptr<Graph> g(new Graph(prog));
96+
ASSERT_EQ(g->nodes[0]->Name(), "sum");
97+
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
98+
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
99+
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
100+
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
101+
ASSERT_EQ(g->nodes[1]->Name(), "test_a");
102+
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
103+
ASSERT_EQ(g->nodes[2]->Name(), "test_b");
104+
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
105+
ASSERT_EQ(g->nodes[3]->Name(), "test_c");
106+
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
107+
ASSERT_EQ(g->nodes[4]->Name(), "test_out");
108+
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
109+
ASSERT_EQ(g->nodes.size(), 5);
110+
}
111+
} // namespace framework
112+
} // namespace paddle

0 commit comments

Comments
 (0)