
Commit 315e44a

add fetch_barrier_op
1 parent b35ea1a commit 315e44a

File tree

paddle/fluid/framework/executor.cc
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/fetch_barrier_op.cc
python/paddle/fluid/transpiler/distribute_transpiler.py

4 files changed: +126 -48 lines changed

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 1 deletion

@@ -353,7 +353,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
     scope->DeleteScope(local_scope);
   } else {
     // Delete the local scopes created in operators.
-    scope->DropKids();
+    // scope->DropKids();
   }
   if (FLAGS_benchmark) {
     VLOG(2) << "-------------------------------------------------------";
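
This change stops the executor from dropping the child scopes that operators create during a run (the call is commented out rather than removed). A toy Python model of the scope tree, only to illustrate what the disabled DropKids() call would have destroyed; this is a sketch, not Paddle code:

    class Scope:
        def __init__(self):
            self.vars = {}   # variable name -> variable
            self.kids = []   # child scopes created by operators

        def new_scope(self):
            kid = Scope()
            self.kids.append(kid)
            return kid

        def drop_kids(self):
            # Destroys every child scope and all variables held inside them.
            # Commenting the call out, as the diff above does, keeps them alive.
            self.kids = []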

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -199,11 +199,13 @@ if(WITH_DISTRIBUTE)
     op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
+    op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
 else()
-    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
+    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op)
 endif()

 op_library(cross_entropy_op DEPS cross_entropy)

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <future>  // NOLINT
+#include <ostream>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#include "paddle/fluid/operators/detail/grpc_client.h"
+
+namespace paddle {
+namespace operators {
+
+class FetchBarrierOp : public framework::OperatorBase {
+ public:
+  FetchBarrierOp(const std::string& type,
+                 const framework::VariableNameMap& inputs,
+                 const framework::VariableNameMap& outputs,
+                 const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
+    std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
+
+    auto client_var_name = Output("RPCClient");
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
+                            "Can not find variable '%s' in the scope.",
+                            client_var_name);
+    auto* client_var = scope.FindVar(client_var_name);
+    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+
+    PADDLE_ENFORCE(rpc_client->Wait());
+
+    for (auto& ep : eps) {
+      VLOG(3) << "fetch barrier, ep: " << ep;
+      rpc_client->AsyncSendFetchBarrier(ep);
+    }
+    PADDLE_ENFORCE(rpc_client->Wait());
+  }
+};
+
+class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddOutput("RPCClient",
+              "(RPCClient) The RPC client object which is "
+              "initialized at most once.");
+    AddComment(R"DOC(
+FetchBarrier operator
+
+This operator sends a fetch barrier signal to the listen_and_serv op, so that
+the Parameter Server knows all variables have been sent.
+)DOC");
+
+    AddAttr<std::vector<std::string>>("endpoints",
+                                      "(string vector, default 127.0.0.1:6164)"
+                                      "Server endpoints to send variables to.")
+        .SetDefault({"127.0.0.1:6164"});
+  }
+};
+
+class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("RPCClient").front();
+    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class FetchBarrierOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
+                  paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
+                  ops::FetchBarrierOpVarTypeInference,
+                  ops::FetchBarrierOpShapeInference);
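
The op registered above is driven from Python by appending it to a program block, which is exactly what the transpiler change below does. A minimal usage sketch, where rpc_client_var and pserver_endpoints stand in for values the transpiler already holds:

    program.global_block().append_op(
        type="fetch_barrier",
        inputs={},
        outputs={"RPCClient": rpc_client_var},
        attrs={"endpoints": pserver_endpoints})

On the C++ side, RunImpl then waits for pending RPCs, sends AsyncSendFetchBarrier to every endpoint in "endpoints", and waits again, so the op only completes once every server has acknowledged the barrier.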

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 21 additions & 46 deletions

@@ -315,12 +315,22 @@ def transpile(self,
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
         send_vars = []
-        for varname, splited_vars in grad_var_mapping.items():
-            index = find_op_by_output_arg(program.global_block(), varname)
+        for orig_varname, splited_vars in grad_var_mapping.items():
             eplist = ps_dispatcher.dispatch(splited_vars)
-            if len(splited_vars) > 1:
-                self._insert_split_op(program, varname, splited_vars)
+            if len(splited_vars) == 1:
+                orig_varname = splited_vars[0].name
+                index = find_op_by_output_arg(program.global_block(),
+                                              orig_varname)
+            elif len(splited_vars) > 1:
+                orig_var = program.global_block().vars[orig_varname]
+                index = find_op_by_output_arg(program.global_block(),
+                                              orig_varname)
+                self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
+            else:
+                raise AssertionError("Can not insert the send op by original "
+                                     "variable name :", orig_varname)
+
             program.global_block().insert_op(
                 index=index + 1,
                 type="send_vars",
@@ -351,6 +361,12 @@ def transpile(self,
                      "RPCClient": rpc_client_var},
             attrs={"epmap": eplist})

+        program.global_block().append_op(
+            type="fetch_barrier",
+            inputs={},
+            outputs={"RPCClient": rpc_client_var},
+            attrs={"endpoints": pserver_endpoints})
+
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
@@ -859,9 +875,7 @@ def _clone_var(self, block, var, persistable=True):
             lod_level=var.lod_level,
             persistable=persistable)

-    def _insert_split_op(self, program, orig_varname, splited_vars):
-        orig_var = program.global_block().vars[orig_varname]
-        index = find_op_by_output_arg(program.global_block(), orig_varname)
+    def _insert_split_op(self, program, orig_var, index, splited_vars):
         if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
             height_sections = []
             for v in splited_vars:
@@ -887,45 +901,6 @@ def _insert_split_op(self, program, orig_varname, splited_vars):
             AssertionError("Variable type should be in set "
                            "[LOD_TENSOR, SELECTED_ROWS]")

-    def _append_split_op(self, program, gradblocks):
-        # Split variables that need to be split and append respective ops
-        add_suffix = False
-        if self.trainer_num > 1:
-            add_suffix = True
-        var_mapping = self._create_vars_from_blocklist(
-            program, gradblocks, add_trainer_suffix=add_suffix)
-        for varname, splited_vars in var_mapping.iteritems():
-            # variable that don't need to split have empty splited_vars
-            if len(splited_vars) <= 1:
-                continue
-            orig_var = program.global_block().vars[varname]
-            index = find_op_by_output_arg(program.global_block(), orig_var.name)
-            if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
-                height_sections = []
-                for v in splited_vars:
-                    height_sections.append(v.shape[0])
-                program.global_block().insert_op(
-                    index=index + 1,
-                    type="split_selected_rows",
-                    inputs={"X": orig_var},
-                    outputs={"Out": splited_vars},
-                    attrs={"height_sections": height_sections})
-            elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
-                sections = []
-                for v in splited_vars:
-                    sections.append(v.shape[0])
-                program.global_block().insert_op(
-                    index=index + 1,
-                    type="split_byref",
-                    inputs={"X": orig_var},
-                    outputs={"Out": splited_vars},
-                    attrs={"sections": sections}  # assume split evenly
-                )
-            else:
-                AssertionError("Variable type should be in set "
-                               "[LOD_TENSOR, SELECTED_ROWS]")
-        return var_mapping
-
     def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
                                    param_shape):
         """
