Skip to content

Commit 1d19849

Browse files
authored
Merge pull request #11370 from panyx0718/dist
Make status update thread-safe
2 parents 183377f + 1509ae3 commit 1d19849

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,22 @@ class RequestBase {
4141
virtual ~RequestBase() {}
4242
virtual void Process() = 0;
4343

44-
CallStatus Status() { return status_; }
45-
void SetStatus(CallStatus status) { status_ = status; }
44+
CallStatus Status() const {
45+
std::lock_guard<std::mutex> l(status_mu_);
46+
return status_;
47+
}
48+
49+
template <typename T>
50+
void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
51+
std::lock_guard<std::mutex> l(status_mu_);
52+
status_ = FINISH;
53+
responder->Finish(reply, ::grpc::Status::OK,
54+
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
55+
}
4656
virtual std::string GetReqName() = 0;
4757

4858
protected:
59+
mutable std::mutex status_mu_;
4960
::grpc::ServerContext ctx_;
5061
GrpcService::AsyncService* service_;
5162
::grpc::ServerCompletionQueue* cq_;
@@ -80,9 +91,7 @@ class RequestSend final : public RequestBase {
8091
framework::Variable* outvar = nullptr;
8192

8293
request_handler_->Handle(varname, scope, invar, &outvar);
83-
status_ = FINISH;
84-
responder_.Finish(reply_, ::grpc::Status::OK,
85-
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
94+
Finish(reply_, &responder_);
8695
}
8796

8897
protected:
@@ -122,9 +131,7 @@ class RequestGet final : public RequestBase {
122131
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
123132
&reply_);
124133
}
125-
status_ = FINISH;
126-
responder_.Finish(reply_, ::grpc::Status::OK,
127-
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
134+
Finish(reply_, &responder_);
128135
}
129136

130137
protected:
@@ -166,9 +173,7 @@ class RequestPrefetch final : public RequestBase {
166173

167174
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
168175
&reply_);
169-
responder_.Finish(reply_, ::grpc::Status::OK,
170-
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
171-
status_ = FINISH;
176+
Finish(reply_, &responder_);
172177
}
173178

174179
protected:

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer {
5353
void StartServer() override;
5454

5555
private:
56+
// HandleRequest needs to be thread-safe.
5657
void HandleRequest(
5758
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
5859
std::function<void(const std::string&, int)> TryToRegisterNewOne);

0 commit comments

Comments
 (0)