Skip to content

Commit ed61d67

Browse files
author
chengduo
authored
Fix the interface of Pass::Apply (#16484)
* modify the interface of Pass::Apply test=develop * Polish code test=develop * Fix Travis CI test=develop * fix Pass::Apply interface test=develop * Fix Travis CI test=develop
1 parent 59f75ec commit ed61d67

File tree

122 files changed

+370
-539
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+370
-539
lines changed

paddle/fluid/framework/details/all_reduce_deps_pass.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
4242
return nullptr;
4343
}
4444

45-
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
46-
std::unique_ptr<ir::Graph> graph) const {
45+
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
4746
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
4847

4948
// get vars order
@@ -131,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
131130
VLOG(10) << "pre_op:" << pre_op->DebugString()
132131
<< ", op:" << op->DebugString();
133132
}
134-
135-
return graph;
136133
}
137134

138135
} // namespace details

paddle/fluid/framework/details/all_reduce_deps_pass.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ namespace details {
2424
// TODO(gongwb): overlap allreduce with backward computation.
2525
class AllReduceDepsPass : public ir::Pass {
2626
protected:
27-
std::unique_ptr<ir::Graph> ApplyImpl(
28-
std::unique_ptr<ir::Graph> graph) const override;
27+
void ApplyImpl(ir::Graph* graph) const override;
2928
};
3029

3130
} // namespace details

paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
4646

4747
class AllocContinuousSpaceForGradPass : public ir::Pass {
4848
protected:
49-
std::unique_ptr<ir::Graph> ApplyImpl(
50-
std::unique_ptr<ir::Graph> graph) const override {
49+
void ApplyImpl(ir::Graph *graph) const override {
5150
ir::Graph &result = *graph;
5251

5352
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
6564

6665
if (params_grads.size() == 0) {
6766
VLOG(10) << "Doesn't find gradients";
68-
return std::move(graph);
67+
return;
6968
}
7069

7170
std::unordered_map<std::string, ir::Node *> vars;
@@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
124123

125124
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
126125
fused_var_name, params_grads);
127-
128-
return std::move(graph);
129126
}
130127

131128
template <typename AttrType>

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
204204
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
205205
}
206206

207-
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
208-
std::unique_ptr<ir::Graph> graph,
209-
const std::vector<platform::Place> &places,
210-
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
211-
const size_t &nranks,
207+
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
208+
const std::vector<platform::Place> &places,
209+
const std::string &loss_var_name,
210+
const std::vector<Scope *> &local_scopes,
211+
const size_t &nranks,
212212
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
213-
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
213+
const bool use_cuda,
214+
platform::NCCLContextMap *nccl_ctxs) const {
214215
#else
215-
const bool use_cuda) const {
216+
const bool use_cuda) const {
216217
#endif
217218
// Create a default one if not finalized by user.
218219
CreatePassesFromStrategy(false);
@@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
265266
}
266267
}
267268
VLOG(3) << "Start Apply Pass " << pass->Type();
268-
graph = pass->Apply(std::move(graph));
269+
graph = pass->Apply(graph);
269270
VLOG(3) << "Finish Apply Pass " << pass->Type();
270271
}
271272
return graph;

paddle/fluid/framework/details/build_strategy.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,15 @@ struct BuildStrategy {
120120

121121
// Apply the passes built by the pass_builder_. The passes will be
122122
// applied to the Program and output an ir::Graph.
123-
std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
124-
const std::vector<platform::Place> &places,
125-
const std::string &loss_var_name,
126-
const std::vector<Scope *> &local_scopes,
127-
const size_t &nranks,
123+
ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
124+
const std::string &loss_var_name,
125+
const std::vector<Scope *> &local_scopes,
126+
const size_t &nranks,
128127
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
129-
const bool use_cuda,
130-
platform::NCCLContextMap *nccl_ctxs) const;
128+
const bool use_cuda,
129+
platform::NCCLContextMap *nccl_ctxs) const;
131130
#else
132-
const bool use_cuda) const;
131+
const bool use_cuda) const;
133132
#endif
134133

135134
// If set true, ParallelExecutor would build the main_program into multiple

paddle/fluid/framework/details/eager_deletion_pass.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
170170

171171
class EagerDeletionPass : public ir::Pass {
172172
protected:
173-
std::unique_ptr<ir::Graph> ApplyImpl(
174-
std::unique_ptr<ir::Graph> graph) const override;
173+
void ApplyImpl(ir::Graph *graph) const override;
175174
};
176175

177-
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
178-
std::unique_ptr<ir::Graph> graph) const {
176+
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
179177
auto &ref_cnts =
180178
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
181179
PADDLE_ENFORCE(ref_cnts.empty(),
@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
240238

241239
auto while_op_eager_deletion_pass =
242240
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
243-
return while_op_eager_deletion_pass->Apply(std::move(graph));
241+
while_op_eager_deletion_pass->Apply(graph);
244242
}
245243

246244
} // namespace details

paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ namespace details {
2828

2929
class FuseAllReduceOpPass : public ir::Pass {
3030
protected:
31-
std::unique_ptr<ir::Graph> ApplyImpl(
32-
std::unique_ptr<ir::Graph> graph) const override {
31+
void ApplyImpl(ir::Graph *graph) const override {
3332
ir::Graph &result = *graph;
3433

3534
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
7170

7271
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
7372
if (all_reduce_ops.size() == 0) {
74-
return std::move(graph);
73+
return;
7574
}
7675

7776
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
9998
group_all_reduce_ops, &result);
10099
#endif
101100
}
102-
return std::move(graph);
103101
}
104102

105103
void InsertFusedAllReduce(const std::vector<platform::Place> &places,

paddle/fluid/framework/details/inplace_op_pass.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,22 +144,19 @@ void InplacePass::InitSSAGraphNodes() const {
144144
}
145145
}
146146

147-
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
148-
std::unique_ptr<ir::Graph> graph) const {
147+
void InplacePass::ApplyImpl(ir::Graph* graph) const {
149148
var_nodes_.clear();
150-
view_.Build(graph.get());
149+
view_.Build(graph);
151150
InitSSAGraphNodes();
152151

153152
auto cnt = 0;
154153
for (auto* op : view_.AllOps()) {
155154
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
156155
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
157156
continue;
158-
TryInplaceOpInputOutput(op, graph.get());
157+
TryInplaceOpInputOutput(op, graph);
159158
}
160159
// graph->ResolveHazard(var_nodes_);
161-
162-
return graph;
163160
}
164161

165162
void InplacePass::InplaceModifyDesc(const std::string& var,

paddle/fluid/framework/details/inplace_op_pass.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
6969
InplacePass();
7070

7171
protected:
72-
std::unique_ptr<ir::Graph> ApplyImpl(
73-
std::unique_ptr<ir::Graph> graph) const override;
72+
void ApplyImpl(ir::Graph* graph) const override;
7473

7574
void InitSSAGraphNodes() const;
7675

paddle/fluid/framework/details/memory_optimize_pass.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ namespace paddle {
4444
namespace framework {
4545
namespace details {
4646

47-
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
48-
std::unique_ptr<ir::Graph> graph) const {
47+
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
4948
auto nodes = graph->Nodes();
5049
CollectSkipVarsSet(nodes);
5150

@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
113112

114113
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
115114
RenameVarInGraphDesc(var_name, cache_name, idx);
116-
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
115+
RenameVarInGraphNode(var_name, cache_name, idx, graph);
117116
pool_.Erase(cache_name);
118117
}
119118
}
@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
128127
}
129128
}
130129
graph->ResolveHazard(var_nodes_);
131-
132-
return graph;
133130
}
134131

135132
void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {

0 commit comments

Comments (0)