@@ -23,8 +23,59 @@ limitations under the License. */
23
23
namespace paddle {
24
24
namespace framework {
25
25
namespace ir {
26
+ namespace {
27
+ void CheckProgram (const ProgramDesc &program) {
28
+ std::map<int , bool > visit;
29
+ #define _INT (role ) static_cast <int >(role)
30
+
31
+ for (size_t i = 0 ; i < program.Size (); ++i) {
32
+ for (OpDesc *op : program.Block (i).AllOps ()) {
33
+ int role_id = boost::get<int >(
34
+ op->GetAttr (OpProtoAndCheckerMaker::OpRoleAttrName ()));
35
+ visit[role_id] = true ;
36
+ switch (role_id) {
37
+ case _INT (OpRole::kForward ):
38
+ PADDLE_ENFORCE (
39
+ visit.find (_INT (OpRole::kBackward )) == visit.end (),
40
+ " Cannot add forward operator before backward operator." );
41
+ break ;
42
+ case _INT (OpRole::kBackward ):
43
+ case _INT (OpRole::kBackward ) | _INT (OpRole::kLoss ):
44
+ PADDLE_ENFORCE (
45
+ visit.find (_INT (OpRole::kOptimize )) == visit.end (),
46
+ " Cannot add backward operator before optimize operator." );
47
+ break ;
48
+ case _INT (OpRole::kForward ) | _INT (OpRole::kLoss ):
49
+ PADDLE_ENFORCE (visit.find (_INT (OpRole::kBackward ) |
50
+ _INT (OpRole::kLoss )) == visit.end (),
51
+ " Cannot add backward|loss operator before "
52
+ " forward|loss operator." );
53
+ PADDLE_ENFORCE (
54
+ visit.find (_INT (OpRole::kOptimize )) == visit.end (),
55
+ " Cannot add backward operator before optimize operator." );
56
+ break ;
57
+ case _INT (OpRole::kOptimize ):
58
+ case _INT (OpRole::kOptimize ) | _INT (OpRole::kLRSched ):
59
+ PADDLE_ENFORCE (visit.find (_INT (OpRole::kBackward )) != visit.end (),
60
+ " Optimize operators must follow backward operator." );
61
+ break ;
62
+ case _INT (OpRole::kLRSched ):
63
+ case _INT (OpRole::kDist ):
64
+ case _INT (OpRole::kRPC ):
65
+ case _INT (OpRole::kNotSpecified ):
66
+ break ;
67
+ default :
68
+ LOG (FATAL) << " Unknown operator role. Don't add new role because "
69
+ " you don't know what you are doing." ;
70
+ }
71
+ }
72
+ }
73
+ #undef _INT
74
+ }
75
+ } // namespace
26
76
27
77
Graph::Graph (const ProgramDesc &program) : program_(program) {
78
+ CheckProgram (program_);
28
79
// Make the nodes id start from 0.
29
80
Node::ResetId ();
30
81
auto var_nodes = InitFromProgram (program_);
0 commit comments