Skip to content

Commit a8e8507

Browse files
committed
Refine the profile codes for inference.
1 parent b825c79 commit a8e8507

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
7474
platform::SetDeviceId(dev_id);
7575
#endif
7676
}
77+
// profile
78+
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
79+
platform::RecordEvent record_event(Type(), dev_ctx);
7780
RunImpl(scope, place);
7881
}
7982

@@ -497,9 +500,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
497500
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
498501
this->InferShape(&infer_shape_ctx);
499502
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
500-
auto dev_ctx = pool.Get(place);
501-
// profile
502-
platform::RecordEvent record_event(Type(), dev_ctx);
503+
auto* dev_ctx = pool.Get(place);
503504
// check if op[type] has kernel registered.
504505
auto& all_op_kernels = AllOpKernels();
505506
auto kernels_iter = all_op_kernels.find(type_);

paddle/fluid/inference/tests/test_helper.h

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -115,11 +115,11 @@ void TestInference(const std::string& dirname,
115115
#endif
116116
}
117117

118-
// Enable the profiler
119-
paddle::platform::EnableProfiler(state);
120-
121118
// 2. Initialize the inference_program and load parameters
122119
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
120+
121+
// Enable the profiler
122+
paddle::platform::EnableProfiler(state);
123123
{
124124
paddle::platform::RecordEvent record_event(
125125
"init_program",
@@ -143,6 +143,10 @@ void TestInference(const std::string& dirname,
143143
inference_program = paddle::inference::Load(executor, *scope, dirname);
144144
}
145145
}
146+
// Disable the profiler and print the timing information
147+
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
148+
"load_program_profiler.txt");
149+
paddle::platform::ResetProfiler();
146150

147151
// 3. Get the feed_target_names and fetch_target_names
148152
const std::vector<std::string>& feed_target_names =
@@ -165,6 +169,12 @@ void TestInference(const std::string& dirname,
165169

166170
// 6. Run the inference program
167171
{
172+
// Ignore the profiling results of the first run
173+
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
174+
175+
// Enable the profiler
176+
paddle::platform::EnableProfiler(state);
177+
168178
// Run repeat times to profile the performance
169179
for (int i = 0; i < repeat; ++i) {
170180
paddle::platform::RecordEvent record_event(
@@ -173,12 +183,13 @@ void TestInference(const std::string& dirname,
173183

174184
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
175185
}
176-
}
177186

178-
// Disable the profiler and print the timing information
179-
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
180-
"profiler.txt");
181-
paddle::platform::ResetProfiler();
187+
// Disable the profiler and print the timing information
188+
paddle::platform::DisableProfiler(
189+
paddle::platform::EventSortingKey::kDefault,
190+
"run_inference_profiler.txt");
191+
paddle::platform::ResetProfiler();
192+
}
182193

183194
delete scope;
184195
}

0 commit comments

Comments
 (0)