Skip to content

Commit 58c027c

Browse files
authored
Add rpc profiler flags. (#13989)
Add rpc profiler flags
1 parent d10e54c commit 58c027c

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
8686
// stub context
8787
s->response_call_back_ = nullptr;
8888

89-
platform::RecordEvent record_event(method, p_ctx);
89+
platform::RecordRPCEvent record_event(method, p_ctx);
9090

9191
auto call = s->stub_g_.PrepareUnaryCall(
9292
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
@@ -143,7 +143,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
143143
// stub context
144144
s->response_call_back_ = ProcGetResponse;
145145

146-
platform::RecordEvent record_event(method, p_ctx);
146+
platform::RecordRPCEvent record_event(method, p_ctx);
147147

148148
auto call = s->stub_g_.PrepareUnaryCall(
149149
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
@@ -191,7 +191,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
191191
// stub context
192192
s->response_call_back_ = ProcGetResponse;
193193

194-
platform::RecordEvent record_event(method, p_ctx);
194+
platform::RecordRPCEvent record_event(method, p_ctx);
195195

196196
auto call = s->stub_g_.PrepareUnaryCall(
197197
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
@@ -221,7 +221,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
221221
sendrecv::VariableMessage req;
222222
req.set_varname(BATCH_BARRIER_MESSAGE);
223223

224-
platform::RecordEvent record_event(method, nullptr);
224+
platform::RecordRPCEvent record_event(method, nullptr);
225225

226226
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
227227
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
@@ -246,7 +246,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
246246
sendrecv::VariableMessage req;
247247
req.set_varname(FETCH_BARRIER_MESSAGE);
248248

249-
platform::RecordEvent record_event(method, nullptr);
249+
platform::RecordRPCEvent record_event(method, nullptr);
250250

251251
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
252252
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
@@ -271,7 +271,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
271271
sendrecv::VariableMessage req;
272272
req.set_varname(COMPLETE_MESSAGE);
273273

274-
platform::RecordEvent record_event(method, nullptr);
274+
platform::RecordRPCEvent record_event(method, nullptr);
275275

276276
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
277277
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
@@ -301,7 +301,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
301301
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
302302
req.set_out_varname(dir);
303303

304-
platform::RecordEvent record_event(method, nullptr);
304+
platform::RecordRPCEvent record_event(method, nullptr);
305305

306306
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
307307
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

paddle/fluid/operators/distributed/grpc_serde.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
3636
const platform::DeviceContext& ctx,
3737
::grpc::ByteBuffer* msg,
3838
const std::string& out_name) {
39-
platform::RecordEvent record_event("serial", &ctx);
39+
platform::RecordRPCEvent record_event("serial", &ctx);
4040
// Default DestroyCallback does nothing. When using GPU,
4141
// the CPU buffer needs to be freed.
4242
DestroyCallback destroy_callback = [](void* backing) {};
@@ -148,7 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
148148
const platform::DeviceContext& ctx,
149149
const framework::Scope* scope,
150150
framework::Variable** var) {
151-
platform::RecordEvent record_event("deserial", &ctx);
151+
platform::RecordRPCEvent record_event("deserial", &ctx);
152152
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
153153
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
154154
*var = resp.GetVar();

paddle/fluid/platform/profiler.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ limitations under the License. */
3030
#include "paddle/fluid/platform/device_tracer.h"
3131
#include "paddle/fluid/string/printf.h"
3232

33+
// Flag consulted by RecordRPCEvent: when false (the default), constructing a
// RecordRPCEvent does not create a RecordEvent.
DEFINE_bool(enable_rpc_profiler, false, "Enable rpc profiler or not.");
34+
3335
namespace paddle {
3436
namespace platform {
3537

@@ -193,6 +195,13 @@ RecordEvent::~RecordEvent() {
193195
PopEvent(name_, dev_ctx_);
194196
}
195197

198+
// Creates a RecordEvent for `name` on `dev_ctx` only when
// FLAGS_enable_rpc_profiler is set; otherwise event_ is left empty and no
// RecordEvent is constructed.
RecordRPCEvent::RecordRPCEvent(const std::string& name,
199+
const DeviceContext* dev_ctx) {
200+
if (FLAGS_enable_rpc_profiler) {
201+
// event_ takes ownership, so the RecordEvent is destroyed with this object.
event_.reset(new platform::RecordEvent(name, dev_ctx));
202+
}
203+
}
204+
196205
RecordBlock::RecordBlock(int block_id)
197206
: is_enabled_(false), start_ns_(PosixInNsec()) {
198207
std::lock_guard<std::mutex> l(profiler_mu);

paddle/fluid/platform/profiler.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ struct RecordEvent {
8787
std::string full_name_;
8888
};
8989

90+
// Wrapper around RecordEvent for RPC code paths. Its constructor (defined in
// profiler.cc) creates the wrapped RecordEvent only when
// FLAGS_enable_rpc_profiler is set; otherwise this object holds nothing.
class RecordRPCEvent {
91+
public:
92+
// name and dev_ctx are forwarded to RecordEvent when the flag is enabled.
// dev_ctx may be nullptr when the device is CPU.
93+
RecordRPCEvent(const std::string& name, const DeviceContext* dev_ctx);
94+
// No explicit work: event_ releases the RecordEvent, if one was created.
~RecordRPCEvent() {}
95+
96+
private:
97+
// Empty unless the constructor created an event.
std::unique_ptr<RecordEvent> event_;
98+
};
99+
90100
struct RecordBlock {
91101
explicit RecordBlock(int block_id);
92102
~RecordBlock();

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __bootstrap__():
120120
read_env_flags.append('rpc_deadline')
121121
read_env_flags.append('rpc_server_profile_period')
122122
read_env_flags.append('rpc_server_profile_path')
123+
read_env_flags.append('enable_rpc_profiler')
123124

124125
if core.is_compiled_with_cuda():
125126
read_env_flags += [

0 commit comments

Comments
 (0)