Skip to content

Commit 51a538e

Browse files
Fix style and use enum
test=develop
1 parent ea3538d commit 51a538e

File tree

2 files changed

+51
-47
lines changed

2 files changed

+51
-47
lines changed

paddle/fluid/framework/ngraph_operator.cc

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
3535
{proto::VarType::BOOL, ngraph::element::boolean},
3636
};
3737

38+
typedef enum { /* nGraph support state on ops */
39+
FULL_TRAIN, /* Support full ops for train */
40+
PARTIAL_TRAIN, /* Support partial ops for train */
41+
FULL_TEST, /* Support full list of ops for test */
42+
PARTIAL_TEST /* Support partial list of ops for test */
43+
} op_state;
44+
3845
class NgraphOperator {
3946
public:
4047
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
@@ -44,33 +51,29 @@ class NgraphOperator {
4451
const std::unordered_set<std::string>& persist,
4552
const std::unordered_set<std::string>& fetches,
4653
const std::unordered_set<std::string>& post_op_inputs,
47-
int is_test_or_train)
48-
: scope(scope),
49-
place(place),
50-
fused_ops(ops),
51-
var_type_map(var_type_map),
52-
persistables(persist),
53-
fetches(fetches),
54-
post_op_inputs(post_op_inputs),
55-
is_test_or_train(is_test_or_train) {}
54+
op_state ng_op_state)
55+
: scope_(scope),
56+
place_(place),
57+
fused_ops_(ops),
58+
var_type_map_(var_type_map),
59+
persistables_(persist),
60+
fetches_(fetches),
61+
post_op_inputs_(post_op_inputs),
62+
ng_op_state_(ng_op_state) {}
5663

5764
void Run(const Scope& scope, const platform::Place& place) const;
5865

5966
private:
6067
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
6168
func_cache;
62-
const Scope& scope;
63-
const platform::Place& place;
64-
std::vector<std::shared_ptr<OperatorBase>> fused_ops;
65-
std::unordered_map<std::string, ngraph::element::Type> var_type_map;
66-
std::unordered_set<std::string> persistables;
67-
std::unordered_set<std::string> fetches;
68-
std::unordered_set<std::string> post_op_inputs;
69-
// 0 = default; 1 = (is_test && not is_complete)
70-
// 2 = (is_test && is_complete)
71-
// 3 = (is_training && not is_complete)
72-
// 4 = (is_training && is_complete)
73-
int is_test_or_train;
69+
const Scope& scope_;
70+
const platform::Place& place_;
71+
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
72+
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
73+
std::unordered_set<std::string> persistables_;
74+
std::unordered_set<std::string> fetches_;
75+
std::unordered_set<std::string> post_op_inputs_;
76+
op_state ng_op_state_;
7477
};
7578

7679
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
131134
const ProgramDesc& prog, size_t block_id,
132135
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
133136
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
134-
const std::string& type = "fused_op", const VariableNameMap& inputs = {},
135-
const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
137+
const std::string& type, const VariableNameMap& inputs,
138+
const VariableNameMap& outputs, const AttributeMap& attrs)
136139
: OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
137140
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
138141
it != end; ++it) {
139-
fused_ops.push_back(std::move(*it));
142+
fused_ops_.push_back(std::move(*it));
140143
}
141144

142145
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
143146
(*it)->Type() != kFetchOpType; ++it) {
144147
for (auto& var_name_item : (*it)->Inputs()) {
145148
for (auto& var_name : var_name_item.second) {
146-
post_op_inputs.insert(var_name);
149+
post_op_inputs_.insert(var_name);
147150
}
148151
}
149152
}
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
152155
is_complete = true;
153156
}
154157

155-
process();
158+
Process();
156159
}
157160

158-
void FusedOperator::process() {
159-
auto& bdesc = pdesc.Block(block);
161+
void FusedOperator::Process() {
162+
auto& bdesc = pdesc_.Block(block_);
160163
for (auto& var : bdesc.AllVars()) {
161164
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
162165
var->GetType() == proto::VarType::LOD_TENSOR ||
@@ -175,39 +178,40 @@ void FusedOperator::process() {
175178
PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
176179
var_name);
177180
}
178-
var_type_map[var_name] = pd2ng_type_map[pd_type];
181+
var_type_map_[var_name] = pd2ng_type_map[pd_type];
179182
}
180183

181184
if (var->Persistable()) {
182-
persistables.insert(var->Name());
185+
persistables_.insert(var->Name());
183186
}
184187
}
185188

186189
for (auto* op : bdesc.AllOps()) {
187190
if (op->Type() == kFetchOpType) {
188191
std::string fetch_target_name = op->Input("X")[0];
189-
fetches.insert(fetch_target_name);
192+
fetches_.insert(fetch_target_name);
190193
}
191194
}
192195
}
193196

194197
void FusedOperator::RunImpl(const Scope& scope,
195198
const platform::Place& place) const {
196-
int is_test_or_train = 1;
197-
auto& bdesc = pdesc.Block(block);
199+
op_state ng_op_state = PARTIAL_TEST;
200+
auto& bdesc = pdesc_.Block(block_);
198201
for (auto* op : bdesc.AllOps()) {
199202
if (op->Type().find("_grad") != std::string::npos) {
200-
is_test_or_train = 3;
203+
ng_op_state = PARTIAL_TRAIN;
201204
break;
202205
}
203206
}
204207

205-
if (is_complete) {
206-
is_test_or_train = is_test_or_train == 1 ? 2 : 4;
208+
if (is_full) {
209+
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
207210
}
208211

209-
NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
210-
fetches, post_op_inputs, is_test_or_train);
212+
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
213+
persistables_, fetches_, post_op_inputs_,
214+
ng_op_state);
211215
ngraph_op.Run(scope, place);
212216
}
213217

paddle/fluid/framework/ngraph_operator.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,16 @@ class FusedOperator : public OperatorBase {
5656
void RunImpl(const Scope& scope, const platform::Place& place) const final;
5757

5858
private:
59-
const ProgramDesc pdesc;
60-
size_t block;
61-
std::vector<std::shared_ptr<OperatorBase>> fused_ops;
62-
std::unordered_map<std::string, ngraph::element::Type> var_type_map;
63-
std::unordered_set<std::string> persistables;
64-
std::unordered_set<std::string> fetches;
65-
std::unordered_set<std::string> post_op_inputs;
66-
bool is_complete = false;
59+
const ProgramDesc pdesc_;
60+
size_t block_;
61+
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
62+
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
63+
std::unordered_set<std::string> persistables_;
64+
std::unordered_set<std::string> fetches_;
65+
std::unordered_set<std::string> post_op_inputs_;
66+
bool is_full_ = false;
6767

68-
void process();
68+
void Process();
6969
};
7070
} // namespace framework
7171
} // namespace paddle

0 commit comments

Comments
 (0)