Skip to content

Commit 052c05b

Browse files
authored
Merge pull request #7537 from Yancey1989/distributed_gpu
Fluid distributed supports CUDA place
2 parents 5f44813 + 329f1e0 commit 052c05b

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

paddle/framework/tensor_util.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
315315
desc.data_type(),
316316
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
317317
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
318-
auto cpu_place = new platform::CPUPlace();
319-
framework::Copy(cpu_tensor, *cpu_place, dev_ctx, tensor);
320-
delete cpu_place;
318+
auto dst_place = dev_ctx.GetPlace();
319+
framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor);
321320
#else
322321
PADDLE_THROW("Unexpected branch");
323322
#endif

paddle/operators/detail/grpc_server.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,12 @@ class RequestSend final : public RequestBase {
7979
class RequestGet final : public RequestBase {
8080
public:
8181
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
82-
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
83-
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
82+
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
83+
const platform::DeviceContext* dev_ctx)
84+
: RequestBase(service, cq),
85+
responder_(&ctx_),
86+
scope_(scope),
87+
dev_ctx_(dev_ctx) {
8488
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
8589
}
8690

@@ -92,7 +96,7 @@ class RequestGet final : public RequestBase {
9296
// proc request.
9397
std::string var_name = request_.varname();
9498
auto* var = scope_->FindVar(var_name);
95-
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
99+
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
96100
// TODO(gongwb): check var's info.
97101
responder_.Finish(reply_, grpc::Status::OK, this);
98102
status_ = FINISH;
@@ -103,6 +107,7 @@ class RequestGet final : public RequestBase {
103107
sendrecv::VariableMessage reply_;
104108
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
105109
framework::Scope* scope_;
110+
const platform::DeviceContext* dev_ctx_;
106111
};
107112

108113
void AsyncGRPCServer::RunSyncUpdate() {
@@ -165,7 +170,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
165170
if (is_shut_down_) {
166171
return;
167172
}
168-
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
173+
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
169174
VLOG(4) << "create Requestget status:" << get->Status();
170175
}
171176

paddle/operators/detail/grpc_server.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class RequestBase;
3737

3838
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
3939
public:
40-
explicit AsyncGRPCServer(std::string address) { address_ = address; }
40+
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
4141

4242
void RunSyncUpdate();
4343

@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
4747

4848
void SetScope(framework::Scope *scope) { scope_ = scope; }
4949

50+
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
51+
5052
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
5153

5254
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
@@ -73,6 +75,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
7375

7476
std::string address_;
7577
framework::Scope *scope_;
78+
const platform::DeviceContext *dev_ctx_;
7679
// received variable from RPC, operators fetch variable from this queue.
7780
SimpleBlockQueue<MessageWithName> var_recv_queue_;
7881

paddle/operators/recv_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ class RecvOp : public framework::OperatorBase {
8787
const platform::Place &dev_place) const override {
8888
// FIXME(typhoonzero): no new scopes for every run.
8989
framework::Scope &recv_scope = scope.NewScope();
90+
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
91+
auto &dev_ctx = *pool.Get(dev_place);
92+
93+
// FIXME(Yancey1989): initialize rpc server with laze mode.
9094
rpc_service_->SetScope(&recv_scope);
95+
rpc_service_->SetDevCtx(&dev_ctx);
9196
auto param_list = Attr<std::vector<std::string>>("ParamList");
9297
auto grad_list = Attr<std::vector<std::string>>("GradList");
9398
auto trainer_count = Attr<int>("Trainers");
@@ -136,9 +141,6 @@ class RecvOp : public framework::OperatorBase {
136141
}
137142

138143
auto *var = recv_scope.Var(grad_var_name);
139-
platform::DeviceContextPool &pool =
140-
platform::DeviceContextPool::Instance();
141-
auto &dev_ctx = *pool.Get(dev_place);
142144
detail::DeserializeFromMessage(v.second, dev_ctx, var);
143145
}
144146

paddle/operators/send_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ class SendOp : public framework::OperatorBase {
3333
: OperatorBase(type, inputs, outputs, attrs) {}
3434

3535
void Run(const framework::Scope& scope,
36-
const platform::Place& dev_place) const override {
36+
const platform::Place& place) const override {
3737
auto ins = Inputs("X");
3838
auto outs = Outputs("Out");
3939
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
4040

41-
// FIXME(gongwb): DeviceContext?
42-
auto ctx = platform::CPUDeviceContext();
41+
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
42+
auto& ctx = *pool.Get(place);
4343
for (size_t i = 0; i < ins.size(); i++) {
4444
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
4545
}

0 commit comments

Comments
 (0)