@@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
36
36
public:
37
37
void Make () {
38
38
AddInput (" X" , " " ).AsDuplicable ();
39
- AddOutput (" Out" , " " );
39
+ AddOutput (" Out" , " " ). AsDuplicable () ;
40
40
AddComment (" " );
41
41
}
42
42
};
@@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference {
59
59
block->Var (out_var_name)->SetType (default_var_type);
60
60
}
61
61
};
62
+
63
+ class DummyOpMaker : public OpProtoAndCheckerMaker {
64
+ public:
65
+ void Make () {
66
+ AddInput (" X" , " " ).AsDuplicable ();
67
+ AddOutput (" Out" , " " ).AsDuplicable ();
68
+ AddComment (" " );
69
+ }
70
+ };
71
+
72
+ class DummyOpVarTypeInference : public VarTypeInference {
73
+ public:
74
+ void operator ()(const OpDesc &op_desc, BlockDesc *block) const override {}
75
+ };
62
76
} // namespace framework
63
77
} // namespace paddle
64
78
65
79
REGISTER_OPERATOR (sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
66
80
paddle::framework::SumOpVarTypeInference);
81
+ REGISTER_OPERATOR (dummy, paddle::framework::NOP, paddle::framework::SumOpMaker,
82
+ paddle::framework::SumOpVarTypeInference);
67
83
REGISTER_OPERATOR (sum_without_infer_var_type, paddle::framework::NOP,
68
84
paddle::framework::SumOpMaker);
69
85
@@ -110,5 +126,83 @@ TEST(GraphTest, Basic) {
110
126
}
111
127
ASSERT_EQ (nodes.size (), 5 );
112
128
}
129
+
130
+ TEST (GraphTest, WriteAfterRead) {
131
+ // void Test() {
132
+ ProgramDesc prog;
133
+ auto *op = prog.MutableBlock (0 )->AppendOp ();
134
+ op->SetType (" sum" );
135
+ op->SetInput (" X" , {" a" });
136
+ op->SetOutput (" Out" , {" b" });
137
+ op->SetAttr (" op_role" , 1 );
138
+
139
+ op = prog.MutableBlock (0 )->AppendOp ();
140
+ op->SetType (" dummy" );
141
+ op->SetInput (" X" , {" c" });
142
+ op->SetOutput (" Out" , {" a" });
143
+ op->SetAttr (" op_role" , 1 );
144
+
145
+ prog.MutableBlock (0 )->Var (" a" )->SetType (proto::VarType::LOD_TENSOR);
146
+ prog.MutableBlock (0 )->Var (" b" )->SetType (proto::VarType::LOD_TENSOR);
147
+ prog.MutableBlock (0 )->Var (" c" )->SetType (proto::VarType::LOD_TENSOR);
148
+
149
+ std::unique_ptr<ir::Graph> g (new ir::Graph (prog));
150
+ ir::Node *control_dep1 = nullptr ;
151
+ ir::Node *control_dep2 = nullptr ;
152
+ for (ir::Node *n : g->Nodes ()) {
153
+ if (n->Name () == " sum" ) {
154
+ ASSERT_EQ (n->outputs [0 ]->Name (), " b" );
155
+ ASSERT_TRUE (ir::IsControlDepVar (*n->outputs [1 ]));
156
+ control_dep1 = n->outputs [1 ];
157
+ ASSERT_EQ (n->outputs .size (), 2 );
158
+ }
159
+ if (n->Name () == " dummy" ) {
160
+ ASSERT_EQ (n->inputs [0 ]->Name (), " c" );
161
+ ASSERT_TRUE (ir::IsControlDepVar (*n->inputs [1 ]));
162
+ control_dep2 = n->inputs [1 ];
163
+ ASSERT_EQ (n->inputs .size (), 2 );
164
+ }
165
+ }
166
+ ASSERT_EQ (control_dep1, control_dep2);
167
+ }
168
+
169
+ TEST (GraphTest, WriteAfterWrite) {
170
+ // void Test() {
171
+ ProgramDesc prog;
172
+ auto *op = prog.MutableBlock (0 )->AppendOp ();
173
+ op->SetType (" sum" );
174
+ op->SetInput (" X" , {" a" });
175
+ op->SetOutput (" Out" , {" b" });
176
+ op->SetAttr (" op_role" , 1 );
177
+
178
+ op = prog.MutableBlock (0 )->AppendOp ();
179
+ op->SetType (" dummy" );
180
+ op->SetInput (" X" , {" c" });
181
+ op->SetOutput (" Out" , {" b" });
182
+ op->SetAttr (" op_role" , 1 );
183
+
184
+ prog.MutableBlock (0 )->Var (" a" )->SetType (proto::VarType::LOD_TENSOR);
185
+ prog.MutableBlock (0 )->Var (" b" )->SetType (proto::VarType::LOD_TENSOR);
186
+ prog.MutableBlock (0 )->Var (" c" )->SetType (proto::VarType::LOD_TENSOR);
187
+
188
+ std::unique_ptr<ir::Graph> g (new ir::Graph (prog));
189
+ ir::Node *control_dep1 = nullptr ;
190
+ ir::Node *control_dep2 = nullptr ;
191
+ for (ir::Node *n : g->Nodes ()) {
192
+ if (n->Name () == " sum" ) {
193
+ ASSERT_EQ (n->outputs [0 ]->Name (), " b" );
194
+ ASSERT_TRUE (ir::IsControlDepVar (*n->outputs [1 ]));
195
+ ASSERT_EQ (n->outputs .size (), 2 );
196
+ control_dep1 = n->outputs [1 ];
197
+ }
198
+ if (n->Name () == " dummy" ) {
199
+ ASSERT_EQ (n->inputs [0 ]->Name (), " c" );
200
+ ASSERT_TRUE (ir::IsControlDepVar (*n->inputs [1 ]));
201
+ control_dep2 = n->inputs [1 ];
202
+ ASSERT_EQ (n->inputs .size (), 2 );
203
+ ASSERT_EQ (control_dep1, control_dep2);
204
+ }
205
+ }
206
+ }
113
207
} // namespace framework
114
208
} // namespace paddle
0 commit comments