Skip to content

Commit f5840d8

Browse files
committed
follow comments
1 parent 04bde96 commit f5840d8

File tree

7 files changed

+25
-22
lines changed

7 files changed

+25
-22
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ ParallelExecutor::ParallelExecutor(
8080

8181
// Bcast Parameters to all GPUs
8282
#ifdef PADDLE_WITH_CUDA
83-
auto *nccl_id_var = scope->FindVar("NCCLID");
83+
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
8484
ncclUniqueId *nccl_id = nullptr;
8585
if (nccl_id_var != nullptr) {
8686
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ if(WITH_DISTRIBUTE)
187187
if(WITH_GPU)
188188
op_library(gen_nccl_id_op DEPS nccl_common)
189189
else()
190-
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
190+
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
191191
endif()
192192
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
193193
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
162162
if (var->IsType<ncclUniqueId>()) {
163163
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
164164
NCCL_UNIQUE_ID_BYTES);
165-
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
166-
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
165+
ncclUniqueId& uid = var->Get<ncclUniqueId>();
166+
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
167167

168168
// for serialize NCCL_ID
169169
::grpc::Slice slices(e.size());

paddle/fluid/operators/gen_nccl_id_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ class GenNCCLIdOp : public framework::OperatorBase {
5252
private:
5353
void GenerateAndSend(framework::Scope* scope,
5454
const platform::DeviceContext& dev_ctx) const {
55-
auto var = scope->FindVar("NCCLID");
55+
auto var = scope->FindVar(NCCL_ID_VARNAME);
5656
PADDLE_ENFORCE_NOT_NULL(var);
5757
auto id = var->GetMutable<ncclUniqueId>();
58-
platform::dynload::ncclGetUniqueId(id);
58+
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
5959

6060
std::vector<std::string> endpoint_list =
6161
Attr<std::vector<std::string>>("endpoint_list");
6262
detail::RPCClient client;
6363
for (auto& ep : endpoint_list) {
6464
VLOG(3) << "sending nccl id to " << ep;
65-
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
65+
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
6666
}
6767
client.Wait();
6868
VLOG(3) << "sending completed...";
@@ -71,6 +71,9 @@ class GenNCCLIdOp : public framework::OperatorBase {
7171
void GetIdByServer(framework::Scope* scope,
7272
const platform::DeviceContext& dev_ctx) const {
7373
std::string endpoint = Attr<std::string>("endpoint");
74+
// NOTE: Can not use unique_ptr here because the default
75+
// deleter will call GRPC Server's base class's dtor and
76+
// that will cause a weird crash.
7477
rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
7578
framework::ProgramDesc empty_program;
7679
framework::Executor executor(dev_ctx.GetPlace());

paddle/fluid/operators/test_send_nccl_id.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
3939
void StartServer() {
4040
f::Scope scope;
4141
p::CPUPlace place;
42-
scope.Var("NCCLID");
42+
scope.Var(NCCL_ID_VARNAME);
4343
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
4444
auto& dev_ctx = *pool.Get(p::CPUPlace());
4545

@@ -71,7 +71,7 @@ TEST(SendNcclId, Normal) {
7171
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
7272
auto& dev_ctx = *pool.Get(p::CPUPlace());
7373

74-
auto var = scope.Var("NCCLID");
74+
auto var = scope.Var(NCCL_ID_VARNAME);
7575
// var->SetType(f::proto::VarType_Type_RAW);
7676
auto id = var->GetMutable<ncclUniqueId>();
7777
p::dynload::ncclGetUniqueId(id);
@@ -80,7 +80,7 @@ TEST(SendNcclId, Normal) {
8080
std::string ep = string::Sprintf("127.0.0.1:%d", port);
8181
detail::RPCClient client;
8282

83-
client.AsyncSendVariable(ep, dev_ctx, scope, "NCCLID");
83+
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
8484
client.Wait();
8585
server_thread.join();
8686
auto* ptr = rpc_service.release();

paddle/fluid/platform/nccl_helper.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "paddle/fluid/platform/dynload/nccl.h"
2222
#include "paddle/fluid/platform/enforce.h"
2323

24+
#define NCCL_ID_VARNAME "NCCLID"
25+
2426
namespace paddle {
2527
namespace platform {
2628

@@ -76,7 +78,7 @@ struct NCCLContextMap {
7678

7779
explicit NCCLContextMap(const std::vector<platform::Place> &places,
7880
ncclUniqueId *nccl_id = nullptr,
79-
size_t node_count = 0, size_t trainer_id = 0) {
81+
size_t num_trainers = 0, size_t trainer_id = 0) {
8082
PADDLE_ENFORCE(!places.empty());
8183
order_.reserve(places.size());
8284
for (auto &p : places) {
@@ -94,16 +96,14 @@ struct NCCLContextMap {
9496
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
9597
// if pass nccl_id here, can assume we are doing multi node training
9698
if (nccl_id == nullptr) {
97-
{
98-
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
99-
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
100-
comms.get(), static_cast<int>(order_.size()), order_.data()));
101-
}
99+
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
100+
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
101+
comms.get(), static_cast<int>(order_.size()), order_.data()));
102102
} else {
103-
PADDLE_ENFORCE_GT(node_count, 0);
103+
PADDLE_ENFORCE_GT(num_trainers, 0);
104104
// TODO(wuyi): need to ensure each node have same number of GPUs
105105
{
106-
int nranks = node_count * order_.size();
106+
int nranks = num_trainers * order_.size();
107107
NCCLGroupGuard gurad;
108108
for (auto &gpu_id : order_) {
109109
int rank = trainer_id * order_.size() + gpu_id;

python/paddle/fluid/parallel_executor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self,
3131
allow_op_delay=False,
3232
share_vars_from=None,
3333
use_default_grad_scale=True,
34-
num_nodes=0,
34+
num_trainers=0,
3535
trainer_id=0):
3636
"""
3737
ParallelExecutor can run program in parallel.
@@ -53,10 +53,10 @@ def __init__(self,
5353
gradients of each device and scaled gradients would be
5454
aggregated. Otherwise, a customized scale value should be fed
5555
to the network.
56-
num_nodes(int, default 0): If greater than 0, NCCL will be
56+
num_trainers(int, default 0): If greater than 0, NCCL will be
5757
initialized with multiple ranks of nodes; each node should have the
5858
same number of GPUs. Distributed training will be enabled then.
59-
trainer_id(int, default 0): Must use together with num_nodes.
59+
trainer_id(int, default 0): Must use together with num_trainers.
6060
trainer_id is the "rank" of the current node, starting from 0.
6161
6262
Returns:
@@ -137,7 +137,7 @@ def __init__(self,
137137
local_scopes,
138138
allow_op_delay,
139139
use_default_grad_scale,
140-
num_nodes,
140+
num_trainers,
141141
trainer_id)
142142
self.scope = scope
143143

0 commit comments

Comments
 (0)