Skip to content

Commit 82c61db

Browse files
committed
fix testing
1 parent 0598a4b commit 82c61db

File tree

7 files changed

+117
-101
lines changed

7 files changed

+117
-101
lines changed

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
5252
// stub context
5353
SendProcessor* s = new SendProcessor(ch);
5454
s->Prepare(var_h, time_out);
55-
s->response_call_back_ = NULL;
55+
s->response_call_back_ = nullptr;
5656

5757
auto call = s->stub_g_.PrepareUnaryCall(
5858
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
5757

5858
class BaseProcessor {
5959
public:
60-
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }
60+
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
61+
context_ = nullptr;
62+
}
6163

6264
virtual ~BaseProcessor() {}
6365

@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
105107

106108
::grpc::GenericStub stub_g_;
107109
::grpc::ByteBuffer reply_;
108-
RequestSendCallBack response_call_back_ = NULL;
110+
RequestSendCallBack response_call_back_ = nullptr;
109111
};
110112

111113
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
261261
// This URL explains why shutdown is complicate:
262262
void AsyncGRPCServer::ShutDown() {
263263
is_shut_down_ = true;
264-
ShutdownQueue();
265264
server_->Shutdown();
265+
ShutdownQueue();
266266
}
267267

268268
void AsyncGRPCServer::TryToRegisterNewSendOne() {

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class AsyncGRPCServer final {
4747
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
4848
: address_(address), sync_mode_(sync_mode) {}
4949

50+
~AsyncGRPCServer() {}
51+
5052
void RunSyncUpdate();
5153

5254
// functions to sync server barrier status.

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 83 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -53,109 +53,106 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
5353
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
5454
} else if (var->IsType<ncclUniqueId>()) {
5555
// NOTE: sendrecv only support RAW type for NCCL_ID
56+
VLOG(3) << "serilizing: setting var type nccl id";
5657
e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
5758
}
5859

5960
if (!out_name.empty()) {
6061
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
6162
}
62-
switch (framework::ToVarType(var->Type())) {
63-
case framework::proto::VarType_Type_LOD_TENSOR: {
64-
auto tensor = var->Get<framework::LoDTensor>();
65-
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
66-
framework::ToDataType(tensor.type()));
67-
for (auto& dim : framework::vectorize(tensor.dims())) {
68-
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
69-
}
70-
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
71-
if (lod.size() > 0) {
72-
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
73-
74-
for (auto& each : lod) {
75-
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
76-
2 + // tag + varintlength of submessage
77-
1 + // kLodDataFieldNumber
78-
each.size());
79-
// auto copied from GPU
80-
for (auto& d : each) {
81-
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
82-
}
63+
if (var->IsType<framework::LoDTensor>()) {
64+
// ===========================Tensor==================================
65+
auto tensor = var->Get<framework::LoDTensor>();
66+
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
67+
framework::ToDataType(tensor.type()));
68+
for (auto& dim : framework::vectorize(tensor.dims())) {
69+
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
70+
}
71+
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
72+
if (lod.size() > 0) {
73+
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
74+
75+
for (auto& each : lod) {
76+
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
77+
2 + // tag + varintlength of submessage
78+
1 + // kLodDataFieldNumber
79+
each.size());
80+
// auto copied from GPU
81+
for (auto& d : each) {
82+
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
8383
}
8484
}
85-
if (platform::is_gpu_place(ctx.GetPlace())) {
85+
}
86+
if (platform::is_gpu_place(ctx.GetPlace())) {
8687
#ifdef PADDLE_WITH_CUDA
87-
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
88+
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
89+
platform::CPUPlace cpu;
90+
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
91+
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
92+
payload = memory::Alloc(cpu, copy_size);
93+
94+
memory::Copy(cpu, payload,
95+
boost::get<platform::CUDAPlace>(tensor.place()),
96+
reinterpret_cast<const void*>(tensor.data<void>()),
97+
copy_size, gpu_dev_ctx.stream());
98+
ctx.Wait();
99+
destroy_callback = [](void* backing) {
88100
platform::CPUPlace cpu;
89-
auto& gpu_dev_ctx =
90-
static_cast<const platform::CUDADeviceContext&>(ctx);
91-
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
92-
payload = memory::Alloc(cpu, copy_size);
93-
94-
memory::Copy(cpu, payload,
95-
boost::get<platform::CUDAPlace>(tensor.place()),
96-
reinterpret_cast<const void*>(tensor.data<void>()),
97-
copy_size, gpu_dev_ctx.stream());
98-
ctx.Wait();
99-
destroy_callback = [](void* backing) {
100-
platform::CPUPlace cpu;
101-
memory::Free(cpu, backing);
102-
};
101+
memory::Free(cpu, backing);
102+
};
103103

104104
#endif
105-
} else {
106-
payload = tensor.data<void>();
107-
}
108-
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
109-
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
110-
} break;
111-
case framework::proto::VarType_Type_SELECTED_ROWS: {
112-
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
113-
auto* slr = var->GetMutable<framework::SelectedRows>();
114-
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
115-
framework::ToDataType(slr->value().type()));
116-
for (auto& dim : framework::vectorize(slr->value().dims())) {
117-
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
118-
}
119-
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
120-
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
121-
auto* tensor = slr->mutable_value();
122-
if (platform::is_gpu_place(ctx.GetPlace())) {
105+
} else {
106+
payload = tensor.data<void>();
107+
}
108+
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
109+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
110+
} else if (var->IsType<framework::SelectedRows>()) {
111+
// ===========================SELECTED
112+
// ROWS==================================
113+
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
114+
auto* slr = var->GetMutable<framework::SelectedRows>();
115+
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
116+
framework::ToDataType(slr->value().type()));
117+
for (auto& dim : framework::vectorize(slr->value().dims())) {
118+
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
119+
}
120+
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
121+
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
122+
auto* tensor = slr->mutable_value();
123+
if (platform::is_gpu_place(ctx.GetPlace())) {
123124
#ifdef PADDLE_WITH_CUDA
125+
platform::CPUPlace cpu;
126+
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
127+
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
128+
payload = memory::Alloc(cpu, copy_size);
129+
memory::Copy(cpu, payload,
130+
boost::get<platform::CUDAPlace>(tensor->place()),
131+
reinterpret_cast<const void*>(tensor->data<void>()),
132+
copy_size, gpu_dev_ctx.stream());
133+
ctx.Wait();
134+
destroy_callback = [](void* backing) {
124135
platform::CPUPlace cpu;
125-
auto& gpu_dev_ctx =
126-
static_cast<const platform::CUDADeviceContext&>(ctx);
127-
auto copy_size =
128-
tensor->numel() * framework::SizeOfType(tensor->type());
129-
payload = memory::Alloc(cpu, copy_size);
130-
memory::Copy(cpu, payload,
131-
boost::get<platform::CUDAPlace>(tensor->place()),
132-
reinterpret_cast<const void*>(tensor->data<void>()),
133-
copy_size, gpu_dev_ctx.stream());
134-
ctx.Wait();
135-
destroy_callback = [](void* backing) {
136-
platform::CPUPlace cpu;
137-
memory::Free(cpu, backing);
138-
};
136+
memory::Free(cpu, backing);
137+
};
139138
#endif
140-
} else {
141-
payload = slr->mutable_value()->data<void>();
142-
}
143-
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
144-
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
145-
} break;
146-
case framework::proto::VarType_Type_RAW: {
147-
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
148-
NCCL_UNIQUE_ID_BYTES);
149-
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
150-
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
151-
} break;
152-
default:
153-
PADDLE_THROW("Serialize does not support type: %s",
154-
typeid(var->Type()).name());
155-
break;
139+
} else {
140+
payload = slr->mutable_value()->data<void>();
141+
}
142+
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
143+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
144+
} else if (var->IsType<ncclUniqueId>()) {
145+
// ===========================NCCL ID==================================
146+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
147+
NCCL_UNIQUE_ID_BYTES);
148+
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
149+
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
150+
} else {
151+
PADDLE_THROW("Serialize does not support type: %s",
152+
typeid(var->Type()).name());
156153
}
157154

158-
if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) {
155+
if (var->IsType<ncclUniqueId>()) {
159156
// for serialize NCCL_ID
160157
::grpc::Slice slices(e.size());
161158
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());

paddle/fluid/operators/detail/variable_response.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
371371
meta_.type() == sendrecv::NCCL_ID) &&
372372
meta_.varname() != "",
373373
"meta info should be got first!");
374+
int length = 0;
375+
if (wt != WIRETYPE_LENGTH_DELIMITED ||
376+
!ReadVarintSizeAsInt(&input, &length)) {
377+
return tag;
378+
}
379+
374380
if (meta_.type() == sendrecv::NCCL_ID) {
381+
VLOG(3) << "parse nccl id request";
375382
auto* var = scope_->FindVar(meta_.varname());
376383
if (var != nullptr) {
384+
VLOG(3) << "parse nccl id: length " << length;
377385
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
378-
memcpy(id->internal, meta_.serialized().c_str(),
379-
meta_.serialized().size());
386+
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
387+
length)) {
388+
return tag;
389+
}
390+
// memcpy(id->internal, meta_.serialized().c_str(),
391+
// meta_.serialized().size());
380392
}
381-
}
382-
383-
int length = 0;
384-
if (wt != WIRETYPE_LENGTH_DELIMITED ||
385-
!ReadVarintSizeAsInt(&input, &length)) {
386-
return tag;
393+
break;
387394
}
388395

389396
framework::DDim dims = GetDims(meta_.dims());

paddle/fluid/operators/gen_nccl_id_op.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
3737
void RunImpl(const framework::Scope& scope,
3838
const platform::Place& dev_place) const override {
3939
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
40-
auto& dev_ctx = *pool.Get(dev_place);
40+
// put nccl id in CPUPlace
41+
auto& dev_ctx = *pool.Get(platform::CPUPlace());
4142
int trainer_id = Attr<int>("trainer_id");
4243
framework::Scope& local_scope = scope.NewScope();
4344

@@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
6061
Attr<std::vector<std::string>>("endpoint_list");
6162
detail::RPCClient client;
6263
for (auto& ep : endpoint_list) {
64+
VLOG(3) << "sending nccl id to " << ep;
6365
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
6466
}
6567
client.Wait();
68+
VLOG(3) << "sending completed...";
6669
}
6770

6871
void GetIdByServer(framework::Scope* scope,
@@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
7881

7982
server_thread_.reset(new std::thread(std::bind(
8083
&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get())));
81-
84+
rpc_service_->SetCond(0);
85+
VLOG(3) << "start getting nccl id from trainer 0...";
8286
auto recv = rpc_service_->Get();
83-
rpc_service_->ShutDown();
87+
VLOG(3) << "got nccl id and stop server...";
88+
// rpc_service_->SetCond(1);
89+
// rpc_service_->ShutDown();
90+
rpc_service->Push(LISTEN_TERMINATE_MESSAGE);
91+
VLOG(3) << "rpc server stopped";
8492
// TODO(wuyi): reinit nccl communicators
8593
}
8694

0 commit comments

Comments
 (0)