@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
35
35
{proto::VarType::BOOL, ngraph::element::boolean},
36
36
};
37
37
38
+ typedef enum { /* nGraph support state on ops */
39
+ FULL_TRAIN, /* Support full ops for train */
40
+ PARTIAL_TRAIN, /* Support partial ops for train */
41
+ FULL_TEST, /* Support full list of ops for test */
42
+ PARTIAL_TEST /* Support partial list of ops for test */
43
+ } op_state;
44
+
38
45
class NgraphOperator {
39
46
public:
40
47
explicit NgraphOperator (const Scope& scope, const platform::Place& place,
@@ -44,33 +51,29 @@ class NgraphOperator {
44
51
const std::unordered_set<std::string>& persist,
45
52
const std::unordered_set<std::string>& fetches,
46
53
const std::unordered_set<std::string>& post_op_inputs,
47
- int is_test_or_train )
48
- : scope (scope),
49
- place (place),
50
- fused_ops (ops),
51
- var_type_map (var_type_map),
52
- persistables (persist),
53
- fetches (fetches),
54
- post_op_inputs (post_op_inputs),
55
- is_test_or_train(is_test_or_train ) {}
54
+ op_state ng_op_state )
55
+ : scope_ (scope),
56
+ place_ (place),
57
+ fused_ops_ (ops),
58
+ var_type_map_ (var_type_map),
59
+ persistables_ (persist),
60
+ fetches_ (fetches),
61
+ post_op_inputs_ (post_op_inputs),
62
+ ng_op_state_(ng_op_state ) {}
56
63
57
64
void Run (const Scope& scope, const platform::Place& place) const ;
58
65
59
66
private:
60
67
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
61
68
func_cache;
62
- const Scope& scope;
63
- const platform::Place& place;
64
- std::vector<std::shared_ptr<OperatorBase>> fused_ops;
65
- std::unordered_map<std::string, ngraph::element::Type> var_type_map;
66
- std::unordered_set<std::string> persistables;
67
- std::unordered_set<std::string> fetches;
68
- std::unordered_set<std::string> post_op_inputs;
69
- // 0 = default; 1 = (is_test && not is_complete)
70
- // 2 = (is_test && is_complete)
71
- // 3 = (is_training && not is_complete)
72
- // 4 = (is_training && is_complete)
73
- int is_test_or_train;
69
+ const Scope& scope_;
70
+ const platform::Place& place_;
71
+ std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
72
+ std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
73
+ std::unordered_set<std::string> persistables_;
74
+ std::unordered_set<std::string> fetches_;
75
+ std::unordered_set<std::string> post_op_inputs_;
76
+ op_state ng_op_state_;
74
77
};
75
78
76
79
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
131
134
const ProgramDesc& prog, size_t block_id,
132
135
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
133
136
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
134
- const std::string& type = " fused_op " , const VariableNameMap& inputs = {} ,
135
- const VariableNameMap& outputs = {} , const AttributeMap& attrs = {} )
137
+ const std::string& type, const VariableNameMap& inputs,
138
+ const VariableNameMap& outputs, const AttributeMap& attrs)
136
139
: OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
137
140
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
138
141
it != end; ++it) {
139
- fused_ops .push_back (std::move (*it));
142
+ fused_ops_ .push_back (std::move (*it));
140
143
}
141
144
142
145
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
143
146
(*it)->Type () != kFetchOpType ; ++it) {
144
147
for (auto & var_name_item : (*it)->Inputs ()) {
145
148
for (auto & var_name : var_name_item.second ) {
146
- post_op_inputs .insert (var_name);
149
+ post_op_inputs_ .insert (var_name);
147
150
}
148
151
}
149
152
}
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
152
155
is_complete = true ;
153
156
}
154
157
155
- process ();
158
+ Process ();
156
159
}
157
160
158
- void FusedOperator::process () {
159
- auto & bdesc = pdesc .Block (block );
161
+ void FusedOperator::Process () {
162
+ auto & bdesc = pdesc_ .Block (block_ );
160
163
for (auto & var : bdesc.AllVars ()) {
161
164
if (!(var->GetType () == proto::VarType::SELECTED_ROWS ||
162
165
var->GetType () == proto::VarType::LOD_TENSOR ||
@@ -175,39 +178,40 @@ void FusedOperator::process() {
175
178
PADDLE_THROW (" Data type of var %s not found in pd2ng_type_map" ,
176
179
var_name);
177
180
}
178
- var_type_map [var_name] = pd2ng_type_map[pd_type];
181
+ var_type_map_ [var_name] = pd2ng_type_map[pd_type];
179
182
}
180
183
181
184
if (var->Persistable ()) {
182
- persistables .insert (var->Name ());
185
+ persistables_ .insert (var->Name ());
183
186
}
184
187
}
185
188
186
189
for (auto * op : bdesc.AllOps ()) {
187
190
if (op->Type () == kFetchOpType ) {
188
191
std::string fetch_target_name = op->Input (" X" )[0 ];
189
- fetches .insert (fetch_target_name);
192
+ fetches_ .insert (fetch_target_name);
190
193
}
191
194
}
192
195
}
193
196
194
197
void FusedOperator::RunImpl (const Scope& scope,
195
198
const platform::Place& place) const {
196
- int is_test_or_train = 1 ;
197
- auto & bdesc = pdesc .Block (block );
199
+ op_state ng_op_state = PARTIAL_TEST ;
200
+ auto & bdesc = pdesc_ .Block (block_ );
198
201
for (auto * op : bdesc.AllOps ()) {
199
202
if (op->Type ().find (" _grad" ) != std::string::npos) {
200
- is_test_or_train = 3 ;
203
+ ng_op_state = PARTIAL_TRAIN ;
201
204
break ;
202
205
}
203
206
}
204
207
205
- if (is_complete ) {
206
- is_test_or_train = is_test_or_train == 1 ? 2 : 4 ;
208
+ if (is_full ) {
209
+ ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN ;
207
210
}
208
211
209
- NgraphOperator ngraph_op (scope, place, fused_ops, var_type_map, persistables,
210
- fetches, post_op_inputs, is_test_or_train);
212
+ NgraphOperator ngraph_op (scope, place, fused_ops_, var_type_map_,
213
+ persistables_, fetches_, post_op_inputs_,
214
+ ng_op_state);
211
215
ngraph_op.Run (scope, place);
212
216
}
213
217
0 commit comments