Skip to content

Commit 633756a

Browse files
authored
Merge pull request #8361 from tonyyang-svail/backward_on_parallel_do
Backward on parallel do using nccl
2 parents a040239 + 4b957af commit 633756a

File tree

11 files changed

+191
-41
lines changed

11 files changed

+191
-41
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
5555
var->GetMutable<platform::PlaceList>();
5656
} else if (var_type == proto::VarType::READER) {
5757
var->GetMutable<ReaderHolder>();
58+
} else if (var_type == proto::VarType::NCCL_COM) {
59+
// GetMutable will be called in ncclInit
5860
} else {
5961
PADDLE_THROW(
6062
"Variable type %d is not in "
6163
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
62-
"LOD_RANK_TABLE, PLACE_LIST, READER]",
64+
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
6365
var_type);
6466
}
6567
}
@@ -120,14 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
120122

121123
for (auto& op_desc : block.AllOps()) {
122124
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
123-
VLOG(4) << op->DebugStringEx(local_scope);
124125

125126
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
126127
platform::RecordEvent record_event(op->Type(), pool.Get(place_));
127128

129+
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
128130
op->Run(*local_scope, place_);
129-
// Wait current device context.
130-
VLOG(3) << op->DebugStringEx(local_scope);
131+
131132
if (FLAGS_benchmark) {
132133
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
133134
<< memory::memory_usage(place_);

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ message VarType {
113113
PLACE_LIST = 14;
114114
READER = 15;
115115
CHANNEL = 16;
116+
NCCL_COM = 17;
116117
}
117118

118119
required Type type = 1;

paddle/fluid/operators/nccl_op.cc

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/framework/op_registry.h"
1616
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
17+
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
1718

1819
namespace paddle {
1920
namespace operators {
2021

22+
static constexpr char kParallelScopes[] = "parallel_scopes";
23+
2124
// NCCLinitOp
2225
class NCCLInitOp : public framework::OperatorBase {
2326
public:
@@ -29,11 +32,22 @@ class NCCLInitOp : public framework::OperatorBase {
2932
private:
3033
void RunImpl(const framework::Scope &scope,
3134
const platform::Place &place) const override {
35+
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
36+
"Can not find variable '%s' in the scope.",
37+
kParallelScopes);
3238
const auto &name = Output("Communicator");
3339
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
3440
"Can not find variable '%s' in the scope.", name);
35-
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
36-
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
41+
// A parallel do may not use all the gpus. For example, the batch size is 7
42+
// in the last batch while we have 8 gpus. In this case, parallel_do will
43+
// create 7 parallel scopes, so ncclInitOp should create 7 GPU peers
44+
auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
45+
->Get<std::vector<framework::Scope *>>();
46+
std::vector<int> gpus(parallel_scopes.size());
47+
for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
48+
gpus[i] = i;
49+
}
50+
PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
3751

3852
if (scope.FindVar(name) == nullptr) {
3953
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
@@ -45,17 +59,29 @@ class NCCLInitOp : public framework::OperatorBase {
4559
}
4660
};
4761

62+
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
63+
public:
64+
void operator()(const framework::OpDesc &op_desc,
65+
framework::BlockDesc *block) const override {
66+
auto out_var_name = op_desc.Output("Communicator").front();
67+
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
68+
auto var_type = framework::proto::VarType::NCCL_COM;
69+
out_var.SetType(var_type);
70+
}
71+
};
72+
73+
class NCCLInitOpShapeInference : public framework::InferShapeBase {
74+
public:
75+
void operator()(framework::InferShapeContext *ctx) const override {}
76+
};
77+
4878
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
4979
public:
5080
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
5181
: OpProtoAndCheckerMaker(proto, op_checker) {
82+
AddInput(kParallelScopes, "The working place of parallel do.");
5283
AddOutput("Communicator",
5384
"Create Communicator for communicating between gpus");
54-
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
55-
AddAttr<int>("dtype",
56-
"(int, default 5 (FP32)) "
57-
"Output data type")
58-
.SetDefault(framework::proto::VarType::FP32);
5985
AddComment(R"DOC(
6086
NCCLInit Operator.
6187
@@ -78,7 +104,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
78104
ctx->HasInput("Communicator"),
79105
" Input(Communicator) of AllReduce op input should not be NULL");
80106
PADDLE_ENFORCE(ctx->HasOutput("Out"),
81-
" Input(X) of AllReduce op input should not be NULL");
107+
" Output(Out) of AllReduce op output should not be NULL");
82108

83109
auto x_dims = ctx->GetInputsDim("X");
84110

@@ -215,7 +241,9 @@ Bcast the tensors.
215241

216242
namespace ops = paddle::operators;
217243
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
218-
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
244+
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
245+
ops::NCCLInitOpVarTypeInference,
246+
ops::NCCLInitOpShapeInference);
219247

220248
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
221249
ops::NCCLAllReduceOpMaker);

paddle/fluid/operators/parallel_do_op.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
3030
static constexpr char kParallelScopes[] = "parallel_scopes";
3131

3232
static constexpr char kParallelBlock[] = "sub_block";
33+
static constexpr char kUseNCCL[] = "use_nccl";
3334

3435
using LoDTensor = framework::LoDTensor;
3536
using SelectedRows = framework::SelectedRows;
@@ -194,6 +195,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
194195
AddOutput(kOutputs, "").AsDuplicable();
195196
AddOutput(kParallelScopes, "");
196197
AddAttr<framework::BlockDesc *>(kParallelBlock, "");
198+
AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
199+
.SetDefault(false);
197200
AddComment(R"DOC(
198201
ParallelDo Operator.
199202
)DOC");
@@ -216,7 +219,6 @@ class ParallelDoGradOp : public framework::OperatorBase {
216219

217220
auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
218221
->Get<std::vector<framework::Scope *>>();
219-
220222
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
221223

222224
// feed output@grad
@@ -243,14 +245,34 @@ class ParallelDoGradOp : public framework::OperatorBase {
243245
}
244246
WaitOnPlaces(places);
245247

246-
AccumulateGrad(scope, place, sub_scopes, places);
248+
// NCCL allreduce op will be added by backward,
249+
// so no need to explicitly accumulate grad
250+
if (!(Attr<bool>(kUseNCCL))) {
251+
AccumulateGrad(scope, place, sub_scopes, places);
252+
} else {
253+
for (auto &place : places) {
254+
PADDLE_ENFORCE(platform::is_gpu_place(place),
255+
"NCCL only supports cuda place");
256+
}
257+
}
258+
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
259+
if (s == "@EMPTY@") {
260+
continue;
261+
}
262+
VLOG(3) << "Moving " << s;
263+
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
264+
}
265+
WaitOnPlaces(places);
247266
}
248267

249268
void AccumulateGrad(const framework::Scope &scope,
250269
const platform::Place &place,
251270
const std::vector<framework::Scope *> &sub_scopes,
252271
const platform::PlaceList &places) const {
253272
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
273+
if (s == "@EMPTY@") {
274+
continue;
275+
}
254276
VLOG(3) << "Accumulating " << s;
255277
if (s == framework::kEmptyVarName) continue;
256278
std::string tmp_name;

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ void BindVarDsec(py::module &m) {
239239
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
240240
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
241241
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
242-
.value("READER", proto::VarType::READER);
242+
.value("READER", proto::VarType::READER)
243+
.value("NCCL_COM", proto::VarType::NCCL_COM);
243244
}
244245

245246
void BindOpDesc(py::module &m) {

python/paddle/v2/fluid/backward.py

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,76 @@ def _op_can_be_removed_(op_desc, no_grad_set):
199199
return op_descs
200200

201201

202+
import proto.framework_pb2 as framework_pb2
203+
204+
205+
def serialize_op_decs(op_desc):
206+
protostr = op_desc.serialize_to_string()
207+
proto = framework_pb2.OpDesc.FromString(str(protostr))
208+
return proto.__str__()
209+
210+
211+
def _callback_lookup_(op):
212+
"""
213+
Only used in _append_backward_ops_
214+
Build and returns a callback function for certain op. For example
215+
216+
parallel_do: AllReduce
217+
218+
:param op:
219+
:return: callback function
220+
"""
221+
if op.type == 'parallel_do' and op.attr('use_nccl'):
222+
param_names = set(op.input('parameters'))
223+
param_grad_names = [n + "@GRAD" for n in param_names]
224+
225+
class ParallelDoCallBack(object):
226+
def __init__(self, param_grad_names, parallel_scopes_name):
227+
self.has_inserted_nccl_init = False
228+
self.param_grad_names = param_grad_names
229+
self.parallel_scopes_name = parallel_scopes_name
230+
231+
def __call__(self, block, context):
232+
if not self.has_inserted_nccl_init:
233+
op_desc = _create_op_desc_(
234+
"ncclInit",
235+
{"parallel_scopes": self.parallel_scopes_name},
236+
{"Communicator": ['nccl_com__do_not_change_']}, {})
237+
block.program.global_block().desc.append_op().copy_from(
238+
op_desc)
239+
self.has_inserted_nccl_init = True
240+
241+
current_op_desc = context["__current_op_desc__"]
242+
for o_param in current_op_desc.output_names():
243+
for o_argu in current_op_desc.output(o_param):
244+
if o_argu in self.param_grad_names:
245+
allreduce_out_name = o_argu + "__nccl_all_reduce__"
246+
op_desc = _create_op_desc_(
247+
"ncclAllReduce", {
248+
"X": [o_argu],
249+
"Communicator":
250+
['nccl_com__do_not_change_']
251+
}, {"Out": [allreduce_out_name]},
252+
{"reduction": "ncclSum"})
253+
block.desc.append_op().copy_from(op_desc)
254+
255+
op_desc = _create_op_desc_(
256+
"assign", {"X": [allreduce_out_name]},
257+
{"Out": [o_argu]}, {})
258+
block.desc.append_op().copy_from(op_desc)
259+
260+
return ParallelDoCallBack(param_grad_names,
261+
op.output("parallel_scopes"))
262+
else:
263+
return None
264+
265+
202266
def _append_backward_ops_(block,
203267
ops,
204268
target_block,
205269
no_grad_dict,
206270
grad_to_var,
207-
callback=None):
271+
callbacks=None):
208272
"""
209273
Create all grad ops, and insert them into given block
210274
@@ -220,14 +284,11 @@ def _append_backward_ops_(block,
220284
val(str): corresponding forward variable name
221285
callback(callable object): a callable object used to decorate newly generated grad ops
222286
"""
223-
if callback is None:
224-
225-
def empty_callback(block, context):
226-
pass
227-
228-
callback = empty_callback
229-
elif not hasattr(callback, '__call__'):
230-
raise ValueError("'callback' must be a callable object.")
287+
if callbacks is not None:
288+
assert (isinstance(callbacks, list))
289+
for cb in callbacks:
290+
if not hasattr(cb, '__call__'):
291+
raise ValueError("'callback' must be a callable object.")
231292

232293
# grad_op_descs holds created grad_op, and will be appended to target_block
233294
grad_op_descs = []
@@ -238,8 +299,17 @@ def empty_callback(block, context):
238299
if op.has_attr("sub_block"):
239300
sub_block = program.block(op.block_attr("sub_block"))
240301
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
241-
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
242-
no_grad_dict, grad_to_var)
302+
cb = _callback_lookup_(op)
303+
if cb is not None:
304+
if callbacks is None:
305+
new_callbacks = [cb]
306+
else:
307+
new_callbacks = callbacks + [_callback_lookup_(op)]
308+
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
309+
no_grad_dict, grad_to_var, new_callbacks)
310+
else:
311+
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
312+
no_grad_dict, grad_to_var, callbacks)
243313
grad_sub_block_list.append(grad_sub_block.desc)
244314

245315
# Getting op's corresponding grad_op
@@ -258,7 +328,11 @@ def empty_callback(block, context):
258328
for op_desc in grad_op_descs:
259329
new_op_desc = target_block.desc.append_op()
260330
new_op_desc.copy_from(op_desc)
261-
callback(block=target_block, context=grad_to_var)
331+
grad_to_var["__current_op_desc__"] = new_op_desc
332+
if callbacks is not None:
333+
assert (isinstance(callbacks, list))
334+
for cb in callbacks:
335+
cb(block=target_block, context=grad_to_var)
262336

263337

264338
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
@@ -296,6 +370,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
296370
# infer_shape and infer_type
297371
op_desc.infer_var_type(block.desc)
298372
op_desc.infer_shape(block.desc)
373+
# ncclInit doesn't need to set data_type
374+
if op_desc.type() == 'ncclInit':
375+
continue
299376
for arg in op_desc.output_arg_names():
300377
if arg in new_vars:
301378
_infer_var_data_type_(arg, block)
@@ -335,7 +412,8 @@ def _get_stop_gradients_(program):
335412
return no_grad_dict
336413

337414

338-
def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
415+
def append_backward(loss, parameter_list=None, no_grad_set=None,
416+
callbacks=None):
339417
"""
340418
Append backward part to main_program
341419
@@ -351,6 +429,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
351429
(list[(Variable,Variable)]): list of (parameter, gradient) pair.
352430
"""
353431
assert isinstance(loss, framework.Variable)
432+
if callbacks is not None:
433+
isinstance(callbacks, list)
354434

355435
program = loss.block.program
356436
if no_grad_set is None:
@@ -378,7 +458,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
378458
no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
379459

380460
_append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
381-
grad_to_var, callback)
461+
grad_to_var, callbacks)
382462

383463
# Because calc_gradient may be called multiple times,
384464
# we need rename the internal gradient variables so that they have

python/paddle/v2/fluid/framework.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def find_name(var_list, name):
490490
'feed', 'fetch', 'save', 'load', 'recurrent',
491491
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
492492
'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
493-
'load_combine'
493+
'load_combine', 'ncclInit'
494494
}
495495
if type not in no_kernel_op_set:
496496
self.desc.infer_var_type(self.block.desc)

0 commit comments

Comments
 (0)