Skip to content

Commit 992e38a

Browse files
authored
Merge pull request #14672 from jacquesqiao/cherry-pick-refactor-prefetch
Merge pull request #14589 from jacquesqiao/refactor-prefetch
2 parents 8226d44 + 5b87198 commit 992e38a

26 files changed

+853 / -123 lines changed

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
862862
if (node->Op()->Type() == "fetch_barrier") {
863863
outvar_dev_id =
864864
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
865-
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
865+
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
866866
}
867867
p = places_[outvar_dev_id];
868868
ir::Node *new_node = nullptr;

paddle/fluid/operators/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,13 @@ if (WITH_GPU)
3737
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
3838
endif()
3939

40-
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
40+
SET(OP_PREFETCH_DEPS "")
41+
if (WITH_DISTRIBUTE)
42+
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
43+
endif()
44+
45+
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
46+
4147

4248
# warpctc_op needs cudnn 7 above
4349
if (WITH_GPU AND NOT WIN32)

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -9,36 +9,37 @@ else()
99
endif()
1010
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
1111

12+
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
13+
1214
if(WITH_GRPC)
1315
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
1416
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
1517
PROTO send_recv.proto
1618
DEPS lod_tensor selected_rows memory)
17-
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
19+
1820
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
1921
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
2022
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
2123
cc_test(rpc_server_test SRCS rpc_server_test.cc
2224
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
2325
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
24-
return()
25-
endif()
26-
27-
28-
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
26+
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
27+
else()
28+
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
29+
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2930

30-
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
31-
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
31+
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
32+
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
33+
PROTO send_recv.proto
34+
DEPS lod_tensor selected_rows memory)
3235

33-
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
34-
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
35-
PROTO send_recv.proto
36-
DEPS lod_tensor selected_rows memory)
36+
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
3737

38-
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
38+
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
3939

40-
cc_test(brpc_server_test SRCS rpc_server_test.cc
41-
DEPS ${brpc_test_depends} SERIAL)
40+
cc_test(brpc_server_test SRCS rpc_server_test.cc
41+
DEPS ${brpc_test_depends} SERIAL)
4242

43-
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
44-
DEPS ${brpc_test_depends} SERIAL)
43+
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
44+
DEPS ${brpc_test_depends} SERIAL)
45+
endif()

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -171,11 +171,13 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
171171
const framework::Scope& scope,
172172
const std::string& in_var_name,
173173
const std::string& out_var_name,
174+
const std::string& table_name,
174175
int64_t time_out) {
175176
const platform::DeviceContext* p_ctx = &ctx;
176177
const std::string ep_val = ep;
177178
const std::string in_var_name_val = in_var_name;
178179
const std::string out_var_name_val = out_var_name;
180+
const std::string table_name_val = table_name;
179181
const framework::Scope* p_scope = &scope;
180182
const auto ch = GetChannel(ep_val);
181183
GetProcessor* s = new GetProcessor(ch);
@@ -186,11 +188,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
186188
s->Prepare(h, time_out);
187189

188190
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
189-
s, method, h, this] {
191+
s, method, h, table_name_val, this] {
190192
auto* var = p_scope->FindVar(in_var_name_val);
191193

192194
::grpc::ByteBuffer req;
193-
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
195+
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
196+
0, table_name_val);
194197

195198
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
196199

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
194194
const framework::Scope& scope,
195195
const std::string& in_var_name,
196196
const std::string& out_var_name,
197+
const std::string& table_name = "",
197198
int64_t time_out = FLAGS_rpc_deadline) override;
198199

199200
VarHandlePtr AsyncSendBatchBarrier(

paddle/fluid/operators/distributed/grpc_serde.cc

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
4242
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
4343
const platform::DeviceContext& ctx,
4444
::grpc::ByteBuffer* msg, const std::string& out_name,
45-
const int trainer_id) {
45+
const int trainer_id,
46+
const std::string& table_name) {
4647
platform::RecordRPCEvent record_event("serial", &ctx);
4748
VarMsg request;
4849
TensorPayload* payload = nullptr;
@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
6364
if (!out_name.empty()) {
6465
request.set_out_varname(out_name);
6566
}
67+
if (!table_name.empty()) {
68+
request.set_table_name(table_name);
69+
}
6670
if (var->IsType<framework::LoDTensor>()) {
6771
request.set_type(::sendrecv::LOD_TENSOR);
6872
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));

paddle/fluid/operators/distributed/grpc_serde.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
4040
const platform::DeviceContext& ctx,
4141
::grpc::ByteBuffer* msg,
4242
const std::string& out_varname = std::string(),
43-
const int trainer_id = 0);
43+
const int trainer_id = 0,
44+
const std::string& table_name = std::string());
4445

4546
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
4647
const platform::DeviceContext& ctx,

paddle/fluid/operators/distributed/grpc_serde_test.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
130130
math::set_constant(ctx, tensor, 31.9);
131131

132132
::grpc::ByteBuffer msg;
133-
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
133+
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg,
134+
"outvar", 0, "table_name");
134135
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
135136

136137
// deserialize

paddle/fluid/operators/distributed/grpc_server.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,7 @@ class RequestPrefetch final : public RequestBase {
183183
// prefetch process...
184184
std::string in_var_name = request_->Varname();
185185
std::string out_var_name = request_->OutVarname();
186+
std::string table_name = request_->TableName();
186187
int trainer_id = request_->GetTrainerId();
187188
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
188189
<< " out_var_name: " << out_var_name;
@@ -193,7 +194,7 @@ class RequestPrefetch final : public RequestBase {
193194
framework::Variable* outvar = scope->Var(out_var_name);
194195

195196
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
196-
out_var_name);
197+
out_var_name, table_name);
197198

198199
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
199200
&reply_);

paddle/fluid/operators/distributed/grpc_variable_response.cc

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) {
301301
meta_.set_trainer_id(trainer_id);
302302
break;
303303
}
304+
case sendrecv::VariableMessage::kTableNameFieldNumber: {
305+
uint32_t length;
306+
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
307+
return tag;
308+
}
309+
310+
std::string temp;
311+
if (!input.ReadString(&temp, length)) {
312+
return tag;
313+
}
314+
315+
meta_.set_table_name(temp);
316+
break;
317+
}
304318
default: {
305319
// Unknown tag, return unknown error.
306320
return -1;

0 commit comments

Comments
 (0)