Skip to content

Commit 6ab935f

Browse files
authored
Merge pull request #10349 from typhoonzero/gen_nccl_id_op
[Feature] NCCL2 distributed training
2 parents ca5ea65 + 872e55b commit 6ab935f

File tree

15 files changed

+339
-24
lines changed

15 files changed

+339
-24
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor(
5858
const std::unordered_set<std::string> &bcast_vars,
5959
const ProgramDesc &main_program, const std::string &loss_var_name,
6060
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
61-
bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
61+
bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
62+
size_t num_trainers, size_t trainer_id)
6263
: member_(new ParallelExecutorPrivate(places)) {
6364
member_->global_scope_ = scope;
6465

@@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor(
8081

8182
// Bcast Parameters to all GPUs
8283
#ifdef PADDLE_WITH_CUDA
83-
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
84+
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
85+
ncclUniqueId *nccl_id = nullptr;
86+
if (nccl_id_var != nullptr) {
87+
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
88+
}
89+
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
90+
member_->places_, nccl_id, num_trainers, trainer_id));
8491
#endif
8592
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
8693
local_scopes.empty()) { // Is CUDA

paddle/fluid/framework/parallel_executor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class ParallelExecutor {
4141
const std::string& loss_var_name, Scope* scope,
4242
const std::vector<Scope*>& local_scopes,
4343
bool allow_op_delay, bool use_default_grad_scale,
44-
bool balance_parameter_opt_between_cards);
44+
bool balance_parameter_opt_between_cards,
45+
size_t num_trainers = 1, size_t trainer_id = 0);
4546

4647
~ParallelExecutor();
4748

paddle/fluid/operators/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ endif()
186186

187187
add_subdirectory(detail)
188188
if(WITH_DISTRIBUTE)
189+
if(WITH_GPU)
190+
op_library(gen_nccl_id_op DEPS nccl_common)
191+
else()
192+
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
193+
endif()
189194
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
190195
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
191196
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
202207
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
203208
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
204209
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
210+
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
205211
else()
206-
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
212+
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
207213
endif()
208214

209215
op_library(cross_entropy_op DEPS cross_entropy)

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
5252
// stub context
5353
SendProcessor* s = new SendProcessor(ch);
5454
s->Prepare(var_h, time_out);
55-
s->response_call_back_ = NULL;
55+
s->response_call_back_ = nullptr;
5656

5757
auto call = s->stub_g_.PrepareUnaryCall(
5858
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
5757

5858
class BaseProcessor {
5959
public:
60-
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }
60+
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
61+
context_ = nullptr;
62+
}
6163

6264
virtual ~BaseProcessor() {}
6365

@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
105107

106108
::grpc::GenericStub stub_g_;
107109
::grpc::ByteBuffer reply_;
108-
RequestSendCallBack response_call_back_ = NULL;
110+
RequestSendCallBack response_call_back_ = nullptr;
109111
};
110112

111113
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
4747
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
4848
: address_(address), sync_mode_(sync_mode), ready_(0) {}
4949

50+
~AsyncGRPCServer() {}
5051
void WaitServerReady();
5152
void RunSyncUpdate();
5253

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ service SendRecvService {
3232
enum VarType {
3333
LOD_TENSOR = 0;
3434
SELECTED_ROWS = 1;
35+
NCCL_ID = 2;
3536
}
3637

3738
// NOTICE(gongwb):don't modify this proto if you are not

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
1616

17+
#ifdef PADDLE_WITH_CUDA
18+
#include <nccl.h>
19+
#endif
1720
#include <sys/time.h>
1821
#include <thread> // NOLINT
1922

@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
129132
} else if (var->IsType<framework::SelectedRows>()) {
130133
request.set_type(::sendrecv::SELECTED_ROWS);
131134
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
135+
#ifdef PADDLE_WITH_CUDA
136+
} else if (var->IsType<ncclUniqueId>()) {
137+
request.set_type(::sendrecv::NCCL_ID);
138+
#endif
132139
} else {
133140
PADDLE_THROW("Serialize does not support type: %s",
134141
typeid(var->Type()).name());
@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
149156
void* buf = buffer.get();
150157
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
151158
e.WriteRawBytes(std::string(header.data(), header.size()));
159+
// NCCLID is copied directly to the message, return bytebuffer
160+
// with only one slice if serializing NCCLID.
161+
#ifdef PADDLE_WITH_CUDA
162+
if (var->IsType<ncclUniqueId>()) {
163+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
164+
NCCL_UNIQUE_ID_BYTES);
165+
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
166+
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
167+
168+
// for serialize NCCL_ID
169+
::grpc::Slice slices(e.size());
170+
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
171+
::grpc::ByteBuffer tmp(&slices, 1);
172+
msg->Swap(&tmp);
173+
return;
174+
}
175+
#endif
176+
152177
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
153178
// steal reference of tensor data
154179
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows

paddle/fluid/operators/detail/variable_response.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include <string>
1818
#include <utility>
1919
#include <vector>
20+
#ifdef PADDLE_WITH_CUDA
21+
#include <nccl.h>
22+
#endif
2023
#include "paddle/fluid/platform/profiler.h"
2124

2225
#include "paddle/fluid/operators/detail/send_recv.pb.h"
@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
368371
}
369372
case sendrecv::VariableMessage::kSerializedFieldNumber: {
370373
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
371-
meta_.type() == sendrecv::LOD_TENSOR) &&
374+
meta_.type() == sendrecv::LOD_TENSOR ||
375+
meta_.type() == sendrecv::NCCL_ID) &&
372376
meta_.varname() != "",
373377
"meta info should be got first!");
374378

@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
378382
return tag;
379383
}
380384

385+
if (meta_.type() == sendrecv::NCCL_ID) {
386+
#ifdef PADDLE_WITH_CUDA
387+
auto* var = scope_->FindVar(meta_.varname());
388+
if (var != nullptr) {
389+
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
390+
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
391+
num_bytes)) {
392+
return tag;
393+
}
394+
}
395+
break;
396+
#else
397+
PADDLE_THROW("Not compiled with CUDA!");
398+
#endif
399+
}
400+
381401
framework::DDim dims = GetDims(meta_.dims());
382402
if (meta_.type() == sendrecv::LOD_TENSOR) {
383403
PADDLE_ENFORCE(meta_.lod_size() >= 0,
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <nccl.h>
16+
#include <stdint.h>
17+
#include <ostream>
18+
#include <string>
19+
20+
#include "paddle/fluid/framework/executor.h"
21+
#include "paddle/fluid/framework/lod_tensor.h"
22+
#include "paddle/fluid/framework/op_registry.h"
23+
#include "paddle/fluid/framework/threadpool.h"
24+
#include "paddle/fluid/operators/detail/grpc_client.h"
25+
#include "paddle/fluid/operators/detail/grpc_server.h"
26+
#include "paddle/fluid/platform/nccl_helper.h"
27+
28+
namespace paddle {
29+
namespace operators {
30+
31+
class GenNCCLIdOp : public framework::OperatorBase {
32+
public:
33+
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
34+
const framework::VariableNameMap& outputs,
35+
const framework::AttributeMap& attrs)
36+
: OperatorBase(type, inputs, outputs, attrs) {}
37+
38+
void RunImpl(const framework::Scope& scope,
39+
const platform::Place& dev_place) const override {
40+
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
41+
// put nccl id in CPUPlace
42+
auto& dev_ctx = *pool.Get(platform::CPUPlace());
43+
int trainer_id = Attr<int>("trainer_id");
44+
framework::Scope& local_scope = scope.NewScope();
45+
46+
if (trainer_id == 0) {
47+
GenerateAndSend(&local_scope, dev_ctx);
48+
} else {
49+
GetIdByServer(&local_scope, dev_ctx);
50+
}
51+
}
52+
53+
private:
54+
void GenerateAndSend(framework::Scope* scope,
55+
const platform::DeviceContext& dev_ctx) const {
56+
auto var = scope->FindVar(NCCL_ID_VARNAME);
57+
PADDLE_ENFORCE_NOT_NULL(var);
58+
auto id = var->GetMutable<ncclUniqueId>();
59+
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
60+
61+
std::vector<std::string> endpoint_list =
62+
Attr<std::vector<std::string>>("endpoint_list");
63+
detail::RPCClient client;
64+
for (auto& ep : endpoint_list) {
65+
VLOG(3) << "sending nccl id to " << ep;
66+
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
67+
}
68+
client.Wait();
69+
VLOG(3) << "sending completed...";
70+
}
71+
72+
void GetIdByServer(framework::Scope* scope,
73+
const platform::DeviceContext& dev_ctx) const {
74+
std::string endpoint = Attr<std::string>("endpoint");
75+
// NOTE: Can not use unique_ptr here because the default
76+
// deleter will call GRPC Server's base class's dtor and
77+
that will cause a weird crash.
78+
detail::AsyncGRPCServer rpc_service(endpoint, true);
79+
framework::ProgramDesc empty_program;
80+
framework::Executor executor(dev_ctx.GetPlace());
81+
rpc_service.SetScope(scope);
82+
rpc_service.SetDevCtx(&dev_ctx);
83+
rpc_service.SetProgram(&empty_program);
84+
rpc_service.SetExecutor(&executor);
85+
86+
std::thread server_thread(
87+
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
88+
rpc_service.SetCond(0);
89+
VLOG(3) << "start getting nccl id from trainer 0...";
90+
auto recv = rpc_service.Get();
91+
VLOG(3) << "got nccl id and stop server...";
92+
rpc_service.ShutDown();
93+
VLOG(3) << "rpc server stopped";
94+
server_thread.join();
95+
}
96+
};
97+
98+
class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
99+
public:
100+
void Make() override {
101+
AddOutput("NCCLID", "Raw variable contains an NCCL UniqueId instance.");
102+
AddComment(R"DOC(
103+
GenNCCLId operator
104+
105+
For trainer 0: generate a new UniqueId and send it to all the other trainers.
106+
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
107+
)DOC");
108+
AddAttr<std::string>("endpoint",
109+
"(string), e.g. 127.0.0.1:6175 "
110+
"current listen endpoint");
111+
AddAttr<std::vector<std::string>>(
112+
"endpoint_list",
113+
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
114+
"list of trainer endpoints starting from trainer 1")
115+
.SetDefault({});
116+
AddAttr<int>("trainer_id",
117+
"(int default 0) "
118+
"The index of the trainer in distributed training.")
119+
.SetDefault(0);
120+
}
121+
};
122+
123+
} // namespace operators
124+
} // namespace paddle
125+
126+
namespace ops = paddle::operators;
127+
128+
REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);

0 commit comments

Comments
 (0)