Skip to content

Commit 627d7a6

Browse files
authored
Clean sendop recv operator. (#11309)
1 parent fa29ef0 commit 627d7a6

File tree

8 files changed

+39
-162
lines changed

8 files changed

+39
-162
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
8989
for (auto *op : program.Block(0).AllOps()) {
9090
// TODO(Yancey1989): use a graceful method to find send op,
9191
// instead of the hard-coded string
92-
if (op->Type() == "send_vars") {
92+
if (op->Type() == "send") {
9393
auto op_vars = op->InputArgumentNames();
9494
send_vars.reserve(send_vars.size() +
9595
std::distance(op_vars.begin(), op_vars.end()));
@@ -468,17 +468,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
468468
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
469469

470470
if (op.Type() == "send_barrier") {
471-
ConnectOp(result, result->ops_.back().get(), "send_vars");
471+
ConnectOp(result, result->ops_.back().get(), "send");
472472
} else if (op.Type() == "recv") {
473473
ConnectOp(result, result->ops_.back().get(), "send_barrier");
474474
} else if (op.Type() == "fetch_barrier") {
475475
ConnectOp(result, result->ops_.back().get(), "recv");
476-
} else if (op.Type() == "send_vars") {
476+
} else if (op.Type() == "send") {
477477
// do nothing
478478
} else {
479479
PADDLE_THROW(
480480
"rpc op should be in ["
481-
"send_vars, send_barrier. recv, fetch_barrier]");
481+
"send, send_barrier, recv, fetch_barrier]");
482482
}
483483

484484
// TODO(Yancey1989): schedule rpc op on different place may

paddle/fluid/operators/CMakeLists.txt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,14 @@ if(WITH_DISTRIBUTE)
189189

190190
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
191191
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
192-
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
193-
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
194192
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS})
195193
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
196194
op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
197195
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
198196
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
199197
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
200-
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
201-
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
198+
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
199+
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
202200
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
203201
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
204202
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -208,15 +206,14 @@ if(WITH_DISTRIBUTE)
208206
# listen_and_serv_op sum_op executor SERIAL)
209207
if(WITH_GPU)
210208
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
211-
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op
212-
listen_and_serv_op executor SERIAL)
209+
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL)
213210
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
214211
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
215212
else()
216213
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
217214
endif()
218215
else()
219-
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
216+
set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
220217
endif()
221218

222219
op_library(cross_entropy_op DEPS cross_entropy)

paddle/fluid/operators/recv_op.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,15 @@ This operator can get variables from server side.
7878
}
7979
};
8080

81+
class RecvOpShapeInference : public framework::InferShapeBase {
82+
public:
83+
void operator()(framework::InferShapeContext* ctx) const override {}
84+
};
85+
8186
} // namespace operators
8287
} // namespace paddle
8388

8489
namespace ops = paddle::operators;
8590

86-
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);
91+
REGISTER_OPERATOR(recv, ops::RecvOp, paddle::framework::EmptyGradOpMaker,
92+
ops::RecvOpMaker, ops::RecvOpShapeInference);

paddle/fluid/operators/send_op.cc

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616
#include <ostream>
1717

1818
#include "paddle/fluid/framework/data_type.h"
19-
#include "paddle/fluid/framework/framework.pb.h"
2019
#include "paddle/fluid/framework/lod_tensor.h"
2120
#include "paddle/fluid/framework/op_registry.h"
2221
#include "paddle/fluid/operators/detail/grpc_client.h"
@@ -36,12 +35,9 @@ class SendOp : public framework::OperatorBase {
3635
void RunImpl(const framework::Scope& scope,
3736
const platform::Place& place) const override {
3837
auto ins = Inputs("X");
39-
auto outs = Outputs("Out");
40-
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
41-
std::vector<std::string> endpoints =
42-
Attr<std::vector<std::string>>("endpoints");
4338

44-
bool sync_mode = Attr<bool>("sync_mode");
39+
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
40+
int sync_send = Attr<int>("sync_mode");
4541

4642
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4743
auto& ctx = *pool.Get(place);
@@ -55,32 +51,14 @@ class SendOp : public framework::OperatorBase {
5551
for (size_t i = 0; i < ins.size(); i++) {
5652
if (NeedSend(scope, ins[i])) {
5753
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
54+
// TODO(Yancey1989): we need to use an IO threadpool which has
55+
// a larger number of threads than the computing threadpool.
5856
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
5957
} else {
6058
VLOG(3) << "don't send no-initialied variable: " << ins[i];
6159
}
6260
}
63-
rpc_client->Wait();
64-
65-
if (sync_mode) {
66-
for (auto& ep : endpoints) {
67-
VLOG(3) << "batch barrier, ep: " << ep;
68-
rpc_client->AsyncSendBatchBarrier(ep);
69-
}
70-
rpc_client->Wait();
71-
}
72-
73-
if (outs.size() > 0) {
74-
for (size_t i = 0; i < outs.size(); i++) {
75-
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
76-
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
77-
}
78-
rpc_client->Wait();
79-
// tell pservers that current trainer have called fetch
80-
for (auto& ep : endpoints) {
81-
VLOG(2) << "send fetch barrier, ep: " << ep;
82-
rpc_client->AsyncSendFetchBarrier(ep);
83-
}
61+
if (sync_send) {
8462
rpc_client->Wait();
8563
}
8664
}
@@ -89,26 +67,22 @@ class SendOp : public framework::OperatorBase {
8967
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
9068
public:
9169
void Make() {
92-
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
93-
AddOutput("Out", "(Tensor) Output tensor to be received from server")
70+
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
9471
.AsDuplicable();
9572
AddComment(R"DOC(
9673
Send operator
9774
98-
This operator will send tensor to recv_op at the parameter server.
75+
This operator will send variables to listen_and_serv op at the parameter server.
9976
)DOC");
100-
// TODO(typhoonzero): remove this attr generate de-duplicated vector from
101-
// epmap when initializing.
102-
AddAttr<std::vector<std::string>>("endpoints",
103-
"(string vector, default 127.0.0.1:6164)"
104-
"Server endpoints to send variables to.")
105-
.SetDefault({});
77+
AddAttr<int>("sync_mode",
78+
"(int, default 0)"
79+
"sync send or async send.")
80+
.SetDefault(0);
10681
AddAttr<std::vector<std::string>>("epmap",
10782
"(string vector, default 127.0.0.1:6164)"
10883
"Server endpoints in the order of input "
10984
"variables for mapping")
110-
.SetDefault({});
111-
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
85+
.SetDefault({"127.0.0.1:6164"});
11286
}
11387
};
11488

paddle/fluid/operators/send_vars_op.cc

Lines changed: 0 additions & 101 deletions
This file was deleted.

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import unittest
1516
import paddle.fluid as fluid
1617
from paddle.fluid.transpiler.distribute_transpiler import delete_ops
1718

@@ -54,10 +55,10 @@ def get_expect_trainer_ops(self):
5455

5556
delete_ops(trainer.global_block(), optimize_ops)
5657
ops = [op.type for op in trainer.global_block().ops] + [
57-
"split_byref", "send_vars", "send_barrier", "recv", "recv",
58+
"split_byref", "send", "send_barrier", "recv", "recv",
5859
"fetch_barrier", "concat"
5960
]
60-
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
61+
ops.insert(ops.index("elementwise_add_grad") + 1, "send")
6162
return ops
6263

6364

python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def get_expect_trainer_ops(self):
5959

6060
delete_ops(trainer.global_block(), optimize_ops)
6161
ops = [op.type for op in trainer.global_block().ops] + [
62-
"send_vars", "send_barrier", "recv", "recv", "fetch_barrier"
62+
"send", "send_barrier", "recv", "recv", "fetch_barrier"
6363
]
64-
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
64+
ops.insert(ops.index("elementwise_add_grad") + 1, "send")
6565
return ops
6666

6767
def _transpiler_instance(self):

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2525
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
2626
3. modify trainer program add split_op to each grad variable.
27-
4. append send_op to send splited variables to server and fetch
28-
params(splited blocks or origin param) from server.
29-
5. append concat_op to merge splited blocks to update local weights.
27+
4. append send_op to send splited variables to server.
28+
5. add recv_op to fetch params(splited blocks or origin param) from server.
29+
6. append concat_op to merge splited blocks to update local weights.
3030
3131
Steps to transpile pserver:
3232
1. create new program for parameter server.
@@ -317,7 +317,7 @@ def transpile(self,
317317

318318
program.global_block().insert_op(
319319
index=index + 1,
320-
type="send_vars",
320+
type="send",
321321
inputs={"X": splited_vars},
322322
outputs={},
323323
attrs={
@@ -678,7 +678,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
678678
break
679679

680680
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
681-
# 2. add split_ids_op and send_vars_op to send gradient to pservers
681+
# 2. add split_ids_op and send_op to send gradient to pservers
682682
# there should only be one table_name
683683
all_ops = program.global_block().ops
684684
table_grad_name = grad_var_name(self.table_name)
@@ -695,11 +695,11 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
695695
outputs={"Out": self.trainer_side_table_grad_list})
696696
program.global_block().insert_op(
697697
index=op_index + 2,
698-
type="send_vars",
698+
type="send",
699699
inputs={'X': self.trainer_side_table_grad_list},
700700
outputs={},
701701
attrs={
702-
"sync_send": True,
702+
"sync_mode": True,
703703
"epmap": pserver_endpoints,
704704
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
705705
})

0 commit comments

Comments
 (0)