Skip to content

Commit 08d22cf

Browse files
authored
Merge pull request #14091 from panyx0718/fix2
add program check
2 parents 91b2851 + a943134 commit 08d22cf

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
1616

1717
#include <gtest/gtest.h>
18+
#include "paddle/fluid/framework/op_proto_maker.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
3637
op->SetInput("X", inputs);
3738
}
3839
op->SetOutput("Out", outputs);
40+
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
41+
static_cast<int>(OpRole::kForward));
3942
}
4043

4144
// a->OP0->b

paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
1616

1717
#include <gtest/gtest.h>
18+
#include "paddle/fluid/framework/op_proto_maker.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
3233
op->SetInput("X", inputs);
3334
}
3435
op->SetOutput("Out", outputs);
36+
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
37+
static_cast<int>(OpRole::kForward));
3538
}
3639

3740
// a->OP0->b

paddle/fluid/framework/ir/graph.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,62 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace framework {
2525
namespace ir {
26+
namespace {
27+
28+
void CheckProgram(const ProgramDesc &program) {
29+
std::map<int, bool> visit;
30+
#define _INT(role) static_cast<int>(role)
31+
32+
for (size_t i = 0; i < program.Size(); ++i) {
33+
for (OpDesc *op : program.Block(i).AllOps()) {
34+
// For backward compatibility, some program doesn't have role added.
35+
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
36+
int role_id = boost::get<int>(
37+
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
38+
visit[role_id] = true;
39+
switch (role_id) {
40+
case _INT(OpRole::kForward):
41+
PADDLE_ENFORCE(
42+
visit.find(_INT(OpRole::kBackward)) == visit.end(),
43+
"Cannot add forward operator before backward operator.");
44+
break;
45+
case _INT(OpRole::kBackward):
46+
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
47+
PADDLE_ENFORCE(
48+
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
49+
"Cannot add backward operator before optimize operator.");
50+
break;
51+
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
52+
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
53+
_INT(OpRole::kLoss)) == visit.end(),
54+
"Cannot add backward|loss operator before "
55+
"forward|loss operator.");
56+
PADDLE_ENFORCE(
57+
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
58+
"Cannot add backward operator before optimize operator.");
59+
break;
60+
case _INT(OpRole::kOptimize):
61+
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
62+
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
63+
"Optimize operators must follow backward operator.");
64+
break;
65+
case _INT(OpRole::kLRSched):
66+
case _INT(OpRole::kDist):
67+
case _INT(OpRole::kRPC):
68+
case _INT(OpRole::kNotSpecified):
69+
break;
70+
default:
71+
LOG(FATAL) << "Unknown operator role. Don't add new role because "
72+
"you don't know what you are doing.";
73+
}
74+
}
75+
}
76+
#undef _INT
77+
}
78+
} // namespace
2679

2780
Graph::Graph(const ProgramDesc &program) : program_(program) {
81+
CheckProgram(program_);
2882
// Make the nodes id start from 0.
2983
Node::ResetId();
3084
auto var_nodes = InitFromProgram(program_);

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
16+
#include "paddle/fluid/framework/op_proto_maker.h"
1617
#include "paddle/fluid/framework/program_desc.h"
1718
#include "paddle/fluid/inference/analysis/ut_helper.h"
1819

@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
130131
op->SetType(type);
131132
op->SetInput("Xs", inputs);
132133
op->SetOutput("Xs", outputs);
134+
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
135+
static_cast<int>(framework::OpRole::kForward));
133136
}
134137

135138
TEST(DataFlowGraph, Build_IR_Graph) {

0 commit comments

Comments
 (0)