Skip to content

Commit c042137

Browse files
authored
Merge pull request #9043 from Xreki/core_inference_remove_clone
Remove unnecessary clone of program in C++ Executor.Run
2 parents df99b16 + 371c53f commit c042137

File tree

5 files changed

+44
-31
lines changed

5 files changed

+44
-31
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
113113
// and feed_holder_name. Raise exception when any mismatch is found.
114114
// Return true if the block has feed operators and holder of matching info.
115115
static bool has_feed_operators(
116-
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
116+
const BlockDesc& block,
117+
std::map<std::string, const LoDTensor*>& feed_targets,
117118
const std::string& feed_holder_name) {
118119
size_t feed_count = 0;
119-
for (auto* op : block->AllOps()) {
120+
for (auto* op : block.AllOps()) {
120121
if (op->Type() == kFeedOpType) {
121122
feed_count++;
122123
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
@@ -135,7 +136,7 @@ static bool has_feed_operators(
135136
"The number of feed operators should match 'feed_targets'");
136137

137138
// When feed operator are present, so should be feed_holder
138-
auto var = block->FindVar(feed_holder_name);
139+
auto var = block.FindVar(feed_holder_name);
139140
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
140141
feed_holder_name);
141142
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
@@ -153,10 +154,10 @@ static bool has_feed_operators(
153154
// and fetch_holder_name. Raise exception when any mismatch is found.
154155
// Return true if the block has fetch operators and holder of matching info.
155156
static bool has_fetch_operators(
156-
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
157+
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
157158
const std::string& fetch_holder_name) {
158159
size_t fetch_count = 0;
159-
for (auto* op : block->AllOps()) {
160+
for (auto* op : block.AllOps()) {
160161
if (op->Type() == kFetchOpType) {
161162
fetch_count++;
162163
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
@@ -175,7 +176,7 @@ static bool has_fetch_operators(
175176
"The number of fetch operators should match 'fetch_targets'");
176177

177178
// When fetch operator are present, so should be fetch_holder
178-
auto var = block->FindVar(fetch_holder_name);
179+
auto var = block.FindVar(fetch_holder_name);
179180
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
180181
fetch_holder_name);
181182
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
@@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
192193
const std::string& feed_holder_name,
193194
const std::string& fetch_holder_name) {
194195
platform::RecordBlock b(kProgramId);
195-
auto* copy_program = new ProgramDesc(program);
196+
bool has_feed_ops =
197+
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
198+
bool has_fetch_ops =
199+
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
200+
201+
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
202+
if (!has_feed_ops || !has_fetch_ops) {
203+
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
204+
}
205+
196206
auto* global_block = copy_program->MutableBlock(0);
197207

198-
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
208+
if (!has_feed_ops) {
199209
// create feed_holder variable
200210
auto* feed_holder = global_block->Var(feed_holder_name);
201211
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
@@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
228238
}
229239
}
230240

231-
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
241+
if (!has_fetch_ops) {
232242
// create fetch_holder variable
233243
auto* fetch_holder = global_block->Var(fetch_holder_name);
234244
fetch_holder->SetType(proto::VarType::FETCH_LIST);
@@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
262272
GetFetchVariable(*scope, fetch_holder_name, idx);
263273
}
264274
}
265-
266-
delete copy_program;
267275
}
268276

269277
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
@@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
313321
} // if (create_vars)
314322

315323
for (auto& op : ctx->ops_) {
316-
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
317-
op->Run(*local_scope, place_);
318324
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
325+
op->Run(*local_scope, place_);
319326

320327
if (FLAGS_benchmark) {
321328
VLOG(2) << "Memory used after operator " + op->Type() + " running: "

paddle/fluid/operators/conv_op.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
7070

7171
framework::OpKernelType ConvOp::GetExpectedKernelType(
7272
const framework::ExecutionContext& ctx) const {
73-
framework::LibraryType library_{framework::LibraryType::kPlain};
73+
framework::LibraryType library{framework::LibraryType::kPlain};
7474
#ifdef PADDLE_WITH_CUDA
7575
if (platform::CanCUDNNBeUsed(ctx)) {
76-
library_ = framework::LibraryType::kCUDNN;
76+
library = framework::LibraryType::kCUDNN;
7777
}
7878
#endif
7979
#ifdef PADDLE_WITH_MKLDNN
80-
if (library_ == framework::LibraryType::kPlain &&
80+
if (library == framework::LibraryType::kPlain &&
8181
platform::CanMKLDNNBeUsed(ctx)) {
82-
library_ = framework::LibraryType::kMKLDNN;
82+
library = framework::LibraryType::kMKLDNN;
8383
}
8484
#endif
8585

@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
9191
"input and filter data type should be consistent");
9292

9393
if (input_data_type == framework::proto::VarType::FP16) {
94-
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
94+
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
9595
"float16 can only be used when CUDNN is used");
9696
}
9797

9898
std::string data_format = ctx.Attr<std::string>("data_format");
9999
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
100-
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
101-
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
102-
library_);
100+
framework::DataLayout layout = framework::StringToDataLayout(data_format);
101+
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
102+
library);
103103
}
104104

105105
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)

paddle/fluid/operators/feed_op.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/feed_fetch_type.h"
1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/framework/operator.h"
18+
#include "paddle/fluid/platform/profiler.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
2829
private:
2930
void RunImpl(const framework::Scope &scope,
3031
const platform::Place &place) const override {
32+
// get device context from pool
33+
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
34+
platform::RecordEvent record_event(Type(), dev_ctx);
35+
3136
auto feed_var_name = Input("X");
3237
auto *feed_var = scope.FindVar(feed_var_name);
3338

@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
5055
auto &feed_item = feed_list.at(static_cast<size_t>(col));
5156
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
5257

53-
// get device context from pool
54-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
55-
auto &dev_ctx = *pool.Get(place);
56-
5758
if (platform::is_same_place(feed_item.place(), place)) {
5859
out_item->ShareDataWith(feed_item);
5960
} else {
60-
framework::TensorCopy(feed_item, place, dev_ctx, out_item);
61+
framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
6162
}
6263
out_item->set_lod(feed_item.lod());
6364
}

paddle/fluid/operators/fetch_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/feed_fetch_type.h"
1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/platform/device_context.h"
18+
#include "paddle/fluid/platform/profiler.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
2930
private:
3031
void RunImpl(const framework::Scope &scope,
3132
const platform::Place &place) const override {
33+
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
34+
platform::RecordEvent record_event(Type(), pool.Get(place));
35+
3236
auto fetch_var_name = Input("X");
3337
auto *fetch_var = scope.FindVar(fetch_var_name);
3438
PADDLE_ENFORCE(fetch_var != nullptr,
@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
5357

5458
// FIXME(yuyang18): Should we assume the fetch operator always generate
5559
// CPU outputs?
56-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
5760
auto &dev_ctx = *pool.Get(src_item.place());
5861

5962
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);

paddle/fluid/operators/load_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515

1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/platform/device_context.h"
18+
#include "paddle/fluid/platform/profiler.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
2930
private:
3031
void RunImpl(const framework::Scope &scope,
3132
const platform::Place &place) const override {
33+
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
34+
platform::RecordEvent record_event(Type(), dev_ctx);
35+
3236
auto filename = Attr<std::string>("file_path");
3337
std::ifstream fin(filename);
3438
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
4145

4246
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
4347

44-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
45-
auto &dev_ctx = *pool.Get(place);
46-
DeserializeFromStream(fin, tensor, dev_ctx);
48+
DeserializeFromStream(fin, tensor, *dev_ctx);
4749

4850
if (platform::is_gpu_place(place)) {
4951
// copy CPU to GPU
@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
5557
out_var->Clear();
5658
tensor = out_var->GetMutable<framework::LoDTensor>();
5759
tensor->set_lod(cpu_tensor.lod());
58-
TensorCopy(cpu_tensor, place, dev_ctx, tensor);
60+
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
5961
}
6062
}
6163
};

0 commit comments

Comments (0)