Skip to content

Commit 5839e32

Browse files
committed
add program check
test=develop
1 parent 5577f9b commit 5839e32

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

paddle/fluid/framework/ir/graph.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,59 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace framework {
2525
namespace ir {
26+
namespace {
27+
void CheckProgram(const ProgramDesc &program) {
28+
std::map<int, bool> visit;
29+
#define _INT(role) static_cast<int>(role)
30+
31+
for (size_t i = 0; i < program.Size(); ++i) {
32+
for (OpDesc *op : program.Block(i).AllOps()) {
33+
int role_id = boost::get<int>(
34+
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
35+
visit[role_id] = true;
36+
switch (role_id) {
37+
case _INT(OpRole::kForward):
38+
PADDLE_ENFORCE(
39+
visit.find(_INT(OpRole::kBackward)) == visit.end(),
40+
"Cannot add forward operator before backward operator.");
41+
break;
42+
case _INT(OpRole::kBackward):
43+
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
44+
PADDLE_ENFORCE(
45+
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
46+
"Cannot add backward operator before optimize operator.");
47+
break;
48+
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
49+
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
50+
_INT(OpRole::kLoss)) == visit.end(),
51+
"Cannot add backward|loss operator before "
52+
"forward|loss operator.");
53+
PADDLE_ENFORCE(
54+
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
55+
"Cannot add backward operator before optimize operator.");
56+
break;
57+
case _INT(OpRole::kOptimize):
58+
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
59+
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
60+
"Optimize operators must follow backward operator.");
61+
break;
62+
case _INT(OpRole::kLRSched):
63+
case _INT(OpRole::kDist):
64+
case _INT(OpRole::kRPC):
65+
case _INT(OpRole::kNotSpecified):
66+
break;
67+
default:
68+
LOG(FATAL) << "Unknown operator role. Don't add new role because "
69+
"you don't know what you are doing.";
70+
}
71+
}
72+
}
73+
#undef _INT
74+
}
75+
} // namespace
2676

2777
Graph::Graph(const ProgramDesc &program) : program_(program) {
78+
CheckProgram(program_);
2879
// Make the nodes id start from 0.
2980
Node::ResetId();
3081
auto var_nodes = InitFromProgram(program_);

0 commit comments

Comments
 (0)