Skip to content

Commit 6a4e923

Browse files
committed
Merge branch 'develop' into mkldnn_test
2 parents b819684 + 078223b commit 6a4e923

24 files changed

+689
-170
lines changed

benchmark/fluid/run.sh

File mode changed from 100644 to 100755.

paddle/fluid/framework/op_desc.h

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -100,16 +100,6 @@ class OpDesc {
100100
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
101101
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
102102

103-
void SetInputMap(const VariableNameMap &input) {
104-
this->inputs_ = input;
105-
this->need_update_ = true;
106-
}
107-
108-
void SetOutputMap(const VariableNameMap &output) {
109-
this->outputs_ = output;
110-
this->need_update_ = true;
111-
}
112-
113103
const VariableNameMap &Inputs() const { return inputs_; }
114104

115105
const VariableNameMap &Outputs() const { return outputs_; }

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,7 @@ void TestWord2vecPrediction(const std::string& model_path) {
5151
config.model_dir = model_path;
5252
config.use_gpu = false;
5353
config.device = 0;
54-
auto predictor =
55-
::paddle::CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
56-
config);
54+
auto predictor = ::paddle::CreatePaddlePredictor<NativeConfig>(config);
5755

5856
// One single batch
5957

paddle/fluid/inference/api/analysis_predictor_tester.cc

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,7 @@ TEST(AnalysisPredictor, ZeroCopy) {
2727
config.model_dir = FLAGS_dirname + "/word2vec.inference.model";
2828
config.use_feed_fetch_ops = false;
2929

30-
auto predictor =
31-
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
32-
config);
30+
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
3331

3432
auto w0 = predictor->GetInputTensor("firstw");
3533
auto w1 = predictor->GetInputTensor("secondw");

paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,8 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
4141
config1.device = 0;
4242
config1.max_batch_size = 10;
4343

44-
auto predictor0 =
45-
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
46-
auto predictor1 =
47-
CreatePaddlePredictor<MixedRTConfig,
48-
PaddleEngineKind::kAutoMixedTensorRT>(config1);
44+
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
45+
auto predictor1 = CreatePaddlePredictor<MixedRTConfig>(config1);
4946

5047
for (int batch_id = 0; batch_id < 1; batch_id++) {
5148
//# 2. Prepare input.

paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -308,18 +308,14 @@ TEST(Analyzer_rnn1, ZeroCopy) {
308308
PaddlePlace place;
309309
int output_size{0};
310310

311-
auto predictor =
312-
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
313-
config);
311+
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
314312

315313
config.use_feed_fetch_ops = true;
316314
auto native_predictor =
317315
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
318316

319317
config.use_feed_fetch_ops = true; // the analysis predictor needs feed/fetch.
320-
auto analysis_predictor =
321-
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
322-
config);
318+
auto analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
323319

324320
#define NEW_TENSOR(name__) \
325321
auto name__##_tensor = predictor->GetInputTensor(#name__);

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,8 +79,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
7979
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
8080
const AnalysisConfig &config, bool use_analysis = true) {
8181
if (use_analysis) {
82-
return CreatePaddlePredictor<contrib::AnalysisConfig,
83-
PaddleEngineKind::kAnalysis>(config);
82+
return CreatePaddlePredictor<contrib::AnalysisConfig>(config);
8483
} else {
8584
return CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
8685
config);

paddle/fluid/inference/tests/api/trt_models_tester.cc

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -51,11 +51,8 @@ void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
5151
config1.model_dir = model_dirname;
5252
config1.max_batch_size = batch_size;
5353

54-
auto predictor0 =
55-
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
56-
auto predictor1 =
57-
CreatePaddlePredictor<MixedRTConfig,
58-
PaddleEngineKind::kAutoMixedTensorRT>(config1);
54+
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
55+
auto predictor1 = CreatePaddlePredictor<MixedRTConfig>(config1);
5956
// Prepare inputs
6057
int height = 224;
6158
int width = 224;

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ if(WITH_GRPC)
2020
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
2121
cc_test(rpc_server_test SRCS rpc_server_test.cc
2222
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
23-
cc_test(varhandle_test SRCS varhandle_test.cc)
23+
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
2424
return()
2525
endif()
2626

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 74 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
7373
const framework::Scope* p_scope = &scope;
7474
const auto ch = GetChannel(ep_val);
7575
SendProcessor* s = new SendProcessor(ch);
76-
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
76+
const std::string method = "SendRPC";
77+
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
7778
s->Prepare(h, time_out);
7879

79-
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
80+
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
8081
auto* var = p_scope->FindVar(var_name_val);
8182

8283
::grpc::ByteBuffer req;
@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
8788
// stub context
8889
s->response_call_back_ = nullptr;
8990

91+
platform::RecordEvent record_event(method, p_ctx);
92+
9093
auto call = s->stub_g_.PrepareUnaryCall(
9194
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
9295
call->StartCall();
9396
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
97+
98+
if (UNLIKELY(platform::IsProfileEnabled())) {
99+
h->Wait();
100+
}
94101
});
95102
req_count_++;
96103

@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
122129
const framework::Scope* p_scope = &scope;
123130
const auto ch = GetChannel(ep_val);
124131
GetProcessor* s = new GetProcessor(ch);
125-
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
132+
const std::string method = "GetRPC";
133+
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
126134
s->Prepare(h, time_out);
127135

128-
framework::AsyncIO([var_name_val, s, this] {
136+
framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
129137
// prepare input
130138
sendrecv::VariableMessage req;
131139
req.set_varname(var_name_val);
@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
137145
// stub context
138146
s->response_call_back_ = ProcGetResponse;
139147

148+
platform::RecordEvent record_event(method, p_ctx);
149+
140150
auto call = s->stub_g_.PrepareUnaryCall(
141151
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
142152
call->StartCall();
143153
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
154+
155+
if (UNLIKELY(platform::IsProfileEnabled())) {
156+
h->Wait();
157+
}
144158
});
145159

146160
req_count_++;
@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
161175
const framework::Scope* p_scope = &scope;
162176
const auto ch = GetChannel(ep_val);
163177
GetProcessor* s = new GetProcessor(ch);
164-
VarHandlePtr h(
165-
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
178+
179+
const std::string method = "PrefetchRPC";
180+
181+
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
166182
s->Prepare(h, time_out);
167183

168184
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
169-
s, this] {
185+
s, method, h, this] {
170186
auto* var = p_scope->FindVar(in_var_name_val);
171187

172188
::grpc::ByteBuffer req;
@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
177193
// stub context
178194
s->response_call_back_ = ProcGetResponse;
179195

196+
platform::RecordEvent record_event(method, p_ctx);
197+
180198
auto call = s->stub_g_.PrepareUnaryCall(
181199
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
182200
&cq_);
183201
call->StartCall();
184202
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
203+
204+
if (UNLIKELY(platform::IsProfileEnabled())) {
205+
h->Wait();
206+
}
185207
});
186208

187209
req_count_++;
@@ -193,31 +215,49 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
193215
const auto ch = GetChannel(ep);
194216

195217
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
196-
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
197-
nullptr, nullptr));
218+
const std::string method = "BatchBarrierRPC";
219+
VarHandlePtr h(
220+
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
198221
s->Prepare(h, time_out);
199222

200223
sendrecv::VariableMessage req;
201224
req.set_varname(BATCH_BARRIER_MESSAGE);
225+
226+
platform::RecordEvent record_event(method, nullptr);
227+
202228
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
203229
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
204230
req_count_++;
231+
232+
if (UNLIKELY(platform::IsProfileEnabled())) {
233+
h->Wait();
234+
}
235+
205236
return h;
206237
}
207238

208239
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
209240
int64_t time_out) {
210241
const auto ch = GetChannel(ep);
211242
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
212-
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
213-
nullptr, nullptr));
243+
const std::string method = "FetchBarrierRPC";
244+
VarHandlePtr h(
245+
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
214246
s->Prepare(h, time_out);
215247

216248
sendrecv::VariableMessage req;
217249
req.set_varname(FETCH_BARRIER_MESSAGE);
250+
251+
platform::RecordEvent record_event(method, nullptr);
252+
218253
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
219254
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
220255
req_count_++;
256+
257+
if (UNLIKELY(platform::IsProfileEnabled())) {
258+
h->Wait();
259+
}
260+
221261
return h;
222262
}
223263

@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
226266
const auto ch = GetChannel(ep);
227267

228268
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
229-
VarHandlePtr h(
230-
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
269+
const std::string method = "SendCompleteRPC";
270+
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
231271
s->Prepare(h, time_out);
232272

233273
sendrecv::VariableMessage req;
234274
req.set_varname(COMPLETE_MESSAGE);
275+
276+
platform::RecordEvent record_event(method, nullptr);
277+
235278
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
236279
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
237280
req_count_++;
281+
282+
if (UNLIKELY(platform::IsProfileEnabled())) {
283+
h->Wait();
284+
}
285+
238286
return h;
239287
}
240288

@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
244292
const auto ch = GetChannel(ep);
245293

246294
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
247-
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
248-
nullptr, nullptr));
295+
296+
const std::string method = "CheckPointNotifyRPC";
297+
298+
VarHandlePtr h(
299+
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
249300
s->Prepare(h, time_out);
250301

251302
sendrecv::VariableMessage req;
252303
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
253304
req.set_out_varname(dir);
254305

306+
platform::RecordEvent record_event(method, nullptr);
307+
255308
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
256309
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
257310
req_count_++;
311+
312+
if (UNLIKELY(platform::IsProfileEnabled())) {
313+
h->Wait();
314+
}
315+
258316
return h;
259317
}
260318

@@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
273331
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
274332
GPR_ASSERT(ok);
275333
PADDLE_ENFORCE(c);
334+
276335
if (c->status_.ok()) {
277336
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
278337
c->Process();

0 commit comments

Comments (0)