Skip to content

Commit 88a607c

Browse files
authored
Merge pull request #12541 from jacquesqiao/optimize-profiler
optimize profiler
2 parents 0fd2f71 + 954d680 commit 88a607c

File tree

9 files changed

+8
-26
lines changed

9 files changed

+8
-26
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
136136
platform::SetDeviceId(dev_id);
137137
#endif
138138
}
139+
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
140+
platform::RecordEvent record_event(Type(), pool.Get(place));
139141
RunImpl(scope, place);
140142
VLOG(10) << "+ " << DebugStringEx(&scope);
141143
}
@@ -639,9 +641,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
639641
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
640642
auto* dev_ctx = pool.Get(place);
641643

642-
// For profiling, don't move out of this function because that will result
643-
// in the failure of multi-GPU profiling.
644-
platform::RecordEvent record_event(Type(), dev_ctx);
645644
// check if op[type] has kernel registered.
646645
auto& all_op_kernels = AllOpKernels();
647646
auto kernels_iter = all_op_kernels.find(type_);

paddle/fluid/operators/feed_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class FeedOp : public framework::OperatorBase {
3131
const platform::Place &place) const override {
3232
// get device context from pool
3333
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
34-
platform::RecordEvent record_event(Type(), dev_ctx);
3534

3635
auto feed_var_name = Input("X");
3736
auto *feed_var = scope.FindVar(feed_var_name);

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ class FetchBarrierOp : public framework::OperatorBase {
3636
void RunImpl(const framework::Scope& scope,
3737
const platform::Place& place) const override {
3838
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
39-
40-
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
41-
auto& ctx = *pool.Get(place);
42-
// For profiling
43-
platform::RecordEvent record_event(Type(), &ctx);
44-
4539
distributed::RPCClient* rpc_client =
4640
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
4741

paddle/fluid/operators/fetch_op.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ class FetchOp : public framework::OperatorBase {
3030
private:
3131
void RunImpl(const framework::Scope &scope,
3232
const platform::Place &place) const override {
33-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
34-
platform::RecordEvent record_event(Type(), pool.Get(place));
35-
3633
auto fetch_var_name = Input("X");
3734
auto *fetch_var = scope.FindVar(fetch_var_name);
3835
PADDLE_ENFORCE(fetch_var != nullptr,

paddle/fluid/operators/load_op.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ class LoadOp : public framework::OperatorBase {
3131
private:
3232
void RunImpl(const framework::Scope &scope,
3333
const platform::Place &place) const override {
34-
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
35-
platform::RecordEvent record_event(Type(), dev_ctx);
36-
3734
// FIXME(yuyang18): We save variable to local file now, but we should change
3835
// it to save an output stream.
3936
auto filename = Attr<std::string>("file_path");

paddle/fluid/operators/recv_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ class RecvOp : public framework::OperatorBase {
4040

4141
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4242
auto& ctx = *pool.Get(place);
43-
// For profiling
44-
platform::RecordEvent record_event(Type(), &ctx);
4543

4644
distributed::RPCClient* rpc_client =
4745
distributed::RPCClient::GetInstance<RPCCLIENT_T>();

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ class SendBarrierOp : public framework::OperatorBase {
3939
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
4040
bool sync_mode = Attr<bool>("sync_mode");
4141

42-
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
43-
auto& ctx = *pool.Get(place);
44-
// For profiling
45-
platform::RecordEvent record_event(Type(), &ctx);
46-
4742
distributed::RPCClient* rpc_client =
4843
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
4944

paddle/fluid/operators/send_op.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ class SendOp : public framework::OperatorBase {
4242
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4343
auto& ctx = *pool.Get(place);
4444

45-
// For profiling
46-
platform::RecordEvent record_event(Type(), &ctx);
47-
4845
distributed::RPCClient* rpc_client =
4946
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
5047

paddle/fluid/platform/profiler.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ Event::Event(EventType type, std::string name, uint32_t thread_id,
110110
has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false;
111111
if (has_cuda_) {
112112
auto* cuda_dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctx);
113+
PADDLE_ENFORCE(cudaSetDevice(
114+
boost::get<platform::CUDAPlace>(cuda_dev_ctx->GetPlace()).device));
113115
PADDLE_ENFORCE(cudaGetDevice(&device_));
114116
PADDLE_ENFORCE(cudaEventCreate(&event_));
115117
auto stream = cuda_dev_ctx->stream();
@@ -176,6 +178,7 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
176178

177179
RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
178180
: is_enabled_(false), start_ns_(PosixInNsec()) {
181+
std::lock_guard<std::mutex> l(profiler_mu);
179182
if (g_state == ProfilerState::kDisabled) return;
180183
is_enabled_ = true;
181184
dev_ctx_ = dev_ctx;
@@ -186,6 +189,7 @@ RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
186189
}
187190

188191
RecordEvent::~RecordEvent() {
192+
std::lock_guard<std::mutex> l(profiler_mu);
189193
if (g_state == ProfilerState::kDisabled || !is_enabled_) return;
190194
DeviceTracer* tracer = GetDeviceTracer();
191195
if (tracer) {
@@ -198,13 +202,15 @@ RecordEvent::~RecordEvent() {
198202

199203
RecordBlock::RecordBlock(int block_id)
200204
: is_enabled_(false), start_ns_(PosixInNsec()) {
205+
std::lock_guard<std::mutex> l(profiler_mu);
201206
if (g_state == ProfilerState::kDisabled) return;
202207
is_enabled_ = true;
203208
SetCurBlock(block_id);
204209
name_ = string::Sprintf("block_%d", block_id);
205210
}
206211

207212
RecordBlock::~RecordBlock() {
213+
std::lock_guard<std::mutex> l(profiler_mu);
208214
if (g_state == ProfilerState::kDisabled || !is_enabled_) return;
209215
DeviceTracer* tracer = GetDeviceTracer();
210216
if (tracer) {

0 commit comments

Comments
 (0)