
Commit 371c53f

Add profiling event in feed, fetch and load op.
1 parent 253ba66 commit 371c53f

4 files changed: +24 additions, -18 deletions
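The three operator changes below all follow the same scoped-profiling pattern: the device context is fetched once from the DeviceContextPool at the top of RunImpl, a platform::RecordEvent is constructed with the operator's type and that context, and the event is closed automatically when the object goes out of scope at the end of the run. A minimal sketch of the pattern, using a hypothetical MyOp in place of the real feed/fetch/load operators:

// Minimal sketch of the scoped-profiling pattern this commit applies.
// MyOp is a hypothetical stand-in; the real edits are in FeedOp, FetchOp
// and LoadOp below. RecordEvent is assumed to open a profiler event on
// construction and close it on destruction.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

class MyOp : public framework::OperatorBase {
 public:
  MyOp(const std::string &type, const framework::VariableNameMap &inputs,
       const framework::VariableNameMap &outputs,
       const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    // Fetch the device context once and reuse it for profiling and for
    // any later tensor copies instead of querying the pool again.
    auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
    platform::RecordEvent record_event(Type(), dev_ctx);

    // ... operator body runs inside the profiled region ...
  }  // record_event leaves scope here and ends the event
};

}  // namespace operators
}  // namespace paddle

This also explains the secondary cleanup in the hunks below: the old mid-function pool lookup is removed, and the existing TensorCopy/DeserializeFromStream calls now dereference the pointer obtained up front.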

paddle/fluid/operators/conv_op.cc

Lines changed: 8 additions & 8 deletions
@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_{framework::LibraryType::kPlain};
+  framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
-    library_ = framework::LibraryType::kCUDNN;
+    library = framework::LibraryType::kCUDNN;
   }
 #endif
 #ifdef PADDLE_WITH_MKLDNN
-  if (library_ == framework::LibraryType::kPlain &&
+  if (library == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
-    library_ = framework::LibraryType::kMKLDNN;
+    library = framework::LibraryType::kMKLDNN;
   }
 #endif
 
@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
                     "input and filter data type should be consistent");
 
   if (input_data_type == framework::proto::VarType::FP16) {
-    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
+    PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
                       "float16 can only be used when CUDNN is used");
   }
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
-                                 library_);
+  framework::DataLayout layout = framework::StringToDataLayout(data_format);
+  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                 library);
 }
 
 Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)

paddle/fluid/operators/feed_op.cc

Lines changed: 6 additions & 5 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
+    // get device context from pool
+    auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    platform::RecordEvent record_event(Type(), dev_ctx);
+
     auto feed_var_name = Input("X");
     auto *feed_var = scope.FindVar(feed_var_name);
 
@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
     auto &feed_item = feed_list.at(static_cast<size_t>(col));
     auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
 
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
     if (platform::is_same_place(feed_item.place(), place)) {
       out_item->ShareDataWith(feed_item);
     } else {
-      framework::TensorCopy(feed_item, place, dev_ctx, out_item);
+      framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
     }
     out_item->set_lod(feed_item.lod());
   }

paddle/fluid/operators/fetch_op.cc

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    platform::RecordEvent record_event(Type(), pool.Get(place));
+
     auto fetch_var_name = Input("X");
     auto *fetch_var = scope.FindVar(fetch_var_name);
     PADDLE_ENFORCE(fetch_var != nullptr,
@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
 
     // FIXME(yuyang18): Should we assume the fetch operator always generate
     // CPU outputs?
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(src_item.place());
 
     TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);

paddle/fluid/operators/load_op.cc

Lines changed: 6 additions & 4 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
+    auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    platform::RecordEvent record_event(Type(), dev_ctx);
+
     auto filename = Attr<std::string>("file_path");
     std::ifstream fin(filename);
     PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
 
     auto *tensor = out_var->GetMutable<framework::LoDTensor>();
 
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    DeserializeFromStream(fin, tensor, dev_ctx);
+    DeserializeFromStream(fin, tensor, *dev_ctx);
 
     if (platform::is_gpu_place(place)) {
       // copy CPU to GPU
@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
       out_var->Clear();
      tensor = out_var->GetMutable<framework::LoDTensor>();
      tensor->set_lod(cpu_tensor.lod());
-      TensorCopy(cpu_tensor, place, dev_ctx, tensor);
+      TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
     }
   }
 };
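These events only produce output when the profiler is active. As a rough usage sketch (not part of this commit), a run can be bracketed with the EnableProfiler/DisableProfiler entry points from paddle/fluid/platform/profiler.h; the names, enum values, and report path below are assumptions and should be checked against the actual header:

// Hedged sketch: bracketing a run with the fluid profiler so the
// RecordEvent entries added here (feed, fetch, load) appear in the report.
// EnableProfiler/DisableProfiler and their enum arguments are assumed from
// paddle/fluid/platform/profiler.h of this period.
#include "paddle/fluid/platform/profiler.h"

void ProfiledRun() {
  paddle::platform::EnableProfiler(paddle::platform::ProfilerState::kCPU);

  // ... build the program and run it through the executor; every feed,
  // fetch and load op now records an event named after its op type ...

  paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kTotal,
                                    "/tmp/fluid_profile");  // hypothetical path
}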
