Skip to content

Commit 8fb78f6

Browse files
committed
fix grpc_server_test
1 parent 4e36c0e commit 8fb78f6

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,14 @@ class RequestPrefetch final : public RequestBase {
158158
std::string in_var_name = request_->Varname();
159159
std::string out_var_name = request_->OutVarname();
160160
VLOG(3) << "in_var_name: " << in_var_name
161+
<< "out_var_name: " << out_var_name
161162
<< " RequestPrefetch: " << out_var_name;
162163

163164
auto scope = request_->GetMutableLocalScope();
164165
auto invar = scope->FindVar(in_var_name);
165-
framework::Variable* outvar = nullptr;
166+
framework::Variable* outvar = scope->FindVar(out_var_name);
166167

167-
request_handler_->Handle(in_var_name, scope, invar, &outvar);
168+
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
168169

169170
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
170171
&reply_);
@@ -284,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
284285
} else if (rpc_name == kRequestPrefetch) {
285286
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
286287
} else {
287-
PADDLE_ENFORCE(false, "not surpported rpc");
288+
PADDLE_ENFORCE(false, "not supported rpc");
288289
}
289290

290291
reqs[req_id] = b;

paddle/fluid/operators/detail/request_handler.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ class RequestHandler {
9696
// *request_handler_->dev_ctx(), &reply_);
9797
// }
9898
virtual bool Handle(const std::string& varname, framework::Scope* scope,
99-
framework::Variable* var,
100-
framework::Variable** outvar) = 0;
99+
framework::Variable* var, framework::Variable** outvar,
100+
const std::string& out_var_name = "") = 0;
101101

102102
protected:
103103
const bool sync_mode_;

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ namespace detail {
3333
bool RequestSendHandler::Handle(const std::string& varname,
3434
framework::Scope* scope,
3535
framework::Variable* invar,
36-
framework::Variable** outvar) {
36+
framework::Variable** outvar,
37+
const std::string& out_var_name) {
3738
VLOG(4) << "RequestSendHandler:" << varname;
3839

3940
// Async
@@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
8283
bool RequestGetHandler::Handle(const std::string& varname,
8384
framework::Scope* scope,
8485
framework::Variable* invar,
85-
framework::Variable** outvar) {
86+
framework::Variable** outvar,
87+
const std::string& out_var_name) {
8688
VLOG(4) << "RequestGetHandler:" << varname;
8789

8890
if (varname != FETCH_BARRIER_MESSAGE) {
@@ -105,11 +107,11 @@ bool RequestGetHandler::Handle(const std::string& varname,
105107
bool RequestPrefetchHandler::Handle(const std::string& varname,
106108
framework::Scope* scope,
107109
framework::Variable* invar,
108-
framework::Variable** outvar) {
110+
framework::Variable** outvar,
111+
const std::string& out_var_name) {
109112
VLOG(4) << "RequestPrefetchHandler " << varname;
110113

111-
auto var_desc = program_->Block(0).FindVar(varname);
112-
*outvar = scope->FindVar(varname);
114+
auto var_desc = program_->Block(0).FindVar(out_var_name);
113115
InitializeVariable(*outvar, var_desc->GetType());
114116
executor_->RunPreparedContext(
115117
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);

paddle/fluid/operators/detail/request_handler_impl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class RequestSendHandler final : public RequestHandler {
4040
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
4141
virtual ~RequestSendHandler() {}
4242
bool Handle(const std::string& varname, framework::Scope* scope,
43-
framework::Variable* var, framework::Variable** outvar) override;
43+
framework::Variable* var, framework::Variable** outvar,
44+
const std::string& out_var_name = "") override;
4445
void ResetSparseVarRecorder();
4546

4647
private:
@@ -53,15 +54,17 @@ class RequestGetHandler final : public RequestHandler {
5354
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
5455
virtual ~RequestGetHandler() {}
5556
bool Handle(const std::string& varname, framework::Scope* scope,
56-
framework::Variable* var, framework::Variable** outvar) override;
57+
framework::Variable* var, framework::Variable** outvar,
58+
const std::string& out_var_name = "") override;
5759
};
5860

5961
class RequestPrefetchHandler final : public RequestHandler {
6062
public:
6163
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
6264
virtual ~RequestPrefetchHandler() {}
6365
bool Handle(const std::string& varname, framework::Scope* scope,
64-
framework::Variable* var, framework::Variable** outvar) override;
66+
framework::Variable* var, framework::Variable** outvar,
67+
const std::string& out_var_name = "") override;
6568
};
6669

6770
} // namespace detail

0 commit comments

Comments
 (0)