
Commit 8653cf3

Merge pull request #10656 from reyoung/feature/support_op_role
Add `op_role` into OpDesc.
2 parents: 16b09d3 + 50dab46
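
For orientation before the diff: this commit gives every `OpDesc` an integer `op_role` attribute whose values are bit flags (see the `OpRole` enum added in `op_proto_maker.h` below), so one op can carry several roles at once, e.g. a loss op that is also part of the backward pass. A minimal sketch of how such flags compose and are tested; the enum values are copied from the diff, while `HasRole` and `main` are illustrative additions, not Paddle code:

#include <cassert>

// Values copied from the OpRole enum added in op_proto_maker.h.
enum class OpRole {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kLoss = 0x0100,
  kNotSpecified = 0x1000,
};

// Illustrative helper (not in the commit): true when role_attr, the int
// stored in the op_role attribute, has the given flag set. kForward is
// 0x0000 and therefore can only be detected by exact equality.
inline bool HasRole(int role_attr, OpRole role) {
  return (role_attr & static_cast<int>(role)) != 0;
}

int main() {
  // append_backward tags the loss op as Forward | Loss and the
  // gradient-seeding fill_constant as Backward | Loss.
  int loss_op_role =
      static_cast<int>(OpRole::kForward) | static_cast<int>(OpRole::kLoss);
  assert(HasRole(loss_op_role, OpRole::kLoss));
  assert(!HasRole(loss_op_role, OpRole::kBackward));
  return 0;
}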

15 files changed: +290 −100 lines (nine of them shown below)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 38 additions & 22 deletions
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/send_op_handle.h"
+#include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"

 #ifdef PADDLE_WITH_CUDA
@@ -159,25 +160,39 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       if (!is_forwarding && places_.size() > 1) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
-        for (auto &og : op->OutputArgumentNames()) {
-          if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
-            switch (strategy_.reduce_) {
-              case BuildStrategy::ReduceStrategy::kReduce:
-                CreateReduceOp(&result, og, cur_device_id);
-                var_name_on_devices[cur_device_id].emplace(og);
-                bcast_var_name_set[cur_device_id].emplace(
-                    og.substr(0, og.size() - strlen(kGradVarSuffix)));
-                cur_device_id = (cur_device_id + 1) % places_.size();
-                break;
-              case BuildStrategy::ReduceStrategy::kAllReduce:
-                if (IsSparseGradient(var_types, og)) {
-                  CreateReduceOp(&result, og, 0);
-                  CreateBroadcastOp(&result, og, 0);
-                } else {
-                  InsertNCCLAllReduceOp(&result, og);
-                }
-                break;
+        if (static_cast<bool>(boost::get<int>(op->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                              static_cast<int>(OpRole::kBackward))) {
+          try {
+            auto backward_vars =
+                boost::get<std::vector<std::string>>(op->GetNullableAttr(
+                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+            for (size_t i = 0; i < backward_vars.size(); i += 2) {
+              auto &p_name = backward_vars[i];
+              auto &g_name = backward_vars[i + 1];
+              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+
+              switch (strategy_.reduce_) {
+                case BuildStrategy::ReduceStrategy::kReduce:
+                  CreateReduceOp(&result, g_name, cur_device_id);
+                  var_name_on_devices[cur_device_id].emplace(g_name);
+                  bcast_var_name_set[cur_device_id].emplace(p_name);
+                  cur_device_id = (cur_device_id + 1) % places_.size();
+                  break;
+                case BuildStrategy::ReduceStrategy::kAllReduce:
+                  if (IsSparseGradient(var_types, g_name)) {
+                    CreateReduceOp(&result, g_name, 0);
+                    CreateBroadcastOp(&result, g_name, 0);
+                  } else {
+                    InsertNCCLAllReduceOp(&result, g_name);
+                  }
+                  break;
+              }
             }
+          } catch (boost::bad_get e) {
           }
         }
       }
@@ -398,11 +413,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
 }

 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
-  // FIXME(yy): Do not hard code like this
-  return op.OutputArgumentNames().size() == 1 &&
-         op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
+  return boost::get<int>(
+             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+             (static_cast<int>(OpRole::kBackward) |
+              static_cast<int>(OpRole::kLoss)) &&
+         !loss_var_name_.empty();  // If loss_var is empty. This is test mode
 }
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
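
The key data layout above is the `op_role_var` attribute: a flat list of names in (parameter, gradient) pairs, `[p0, g0, p1, g1, ...]`, which the builder walks two entries at a time. A standalone sketch of that walk with the Paddle-specific dispatch stubbed out; `ProcessBackwardVars` and the variable names are hypothetical:

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// op_role_var stores (parameter, gradient) name pairs flattened into one
// vector: [p0, g0, p1, g1, ...]. The builder iterates two entries at a time.
void ProcessBackwardVars(const std::vector<std::string> &backward_vars) {
  assert(backward_vars.size() % 2 == 0);  // mirrors PADDLE_ENFORCE_EQ above
  for (size_t i = 0; i < backward_vars.size(); i += 2) {
    const std::string &p_name = backward_vars[i];      // parameter
    const std::string &g_name = backward_vars[i + 1];  // its gradient
    // The real builder dispatches on strategy_.reduce_ here (kReduce vs
    // kAllReduce); this sketch just logs the pairing.
    std::cout << "Bcast " << g_name << " for parameter " << p_name << "\n";
  }
}

int main() {
  // Hypothetical names, for illustration only.
  ProcessBackwardVars({"fc_0.w", "fc_0.w@GRAD", "fc_0.b", "fc_0.b@GRAD"});
  return 0;
}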

paddle/fluid/framework/details/op_registry.h

Lines changed: 1 addition & 4 deletions
@@ -96,10 +96,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
     info->proto_ = new proto::OpProto;
     info->checker_ = new OpAttrChecker();
     T maker;
-    maker.SetProto(info->proto_);
-    maker.SetChecker(info->checker_);
-    maker.Make();
-    maker.Validate();
+    maker(info->proto_, info->checker_);
     info->proto_->set_type(op_type);
     PADDLE_ENFORCE(
         info->proto_->IsInitialized(),
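
This collapses the old four-call protocol (SetProto, SetChecker, Make, Validate) into one functor call, which is what lets the base class inject the shared `op_role` attributes between `Make()` and `Validate()` for every operator. A toy analogue of the pattern, assuming nothing else needs to call the steps separately; all names here are illustrative, not Paddle APIs:

#include <iostream>

// Toy analogue of the refactor: the base class's operator() wraps the
// subclass's Make() and appends attributes shared by every operator.
class MakerBase {
 public:
  virtual ~MakerBase() = default;
  void operator()() {             // was: SetProto/SetChecker/Make/Validate
    Make();                       // op-specific attributes
    AddSharedAttr("op_role");     // injected for every operator
    AddSharedAttr("op_role_var");
    Validate();
  }
  virtual void Make() = 0;

 private:
  void AddSharedAttr(const char *name) {
    std::cout << "attr: " << name << "\n";
  }
  void Validate() { std::cout << "validated\n"; }
};

class MyOpMaker : public MakerBase {
 public:
  void Make() override { std::cout << "attr: my_op_specific\n"; }
};

int main() {
  MyOpMaker maker;
  maker();  // single call replaces the old four-call sequence
  return 0;
}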

paddle/fluid/framework/op_desc.cc

Lines changed: 17 additions & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/shape_inference.h"
@@ -222,6 +223,15 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
   return it->second;
 }

+Attribute OpDesc::GetNullableAttr(const std::string &name) const {
+  auto it = attrs_.find(name);
+  if (it != attrs_.end()) {
+    return it->second;
+  } else {
+    return Attribute();
+  }
+}
+
 int OpDesc::GetBlockAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
@@ -249,6 +259,13 @@ void OpDesc::RenameOutput(const std::string &old_name,
     std::replace(output.second.begin(), output.second.end(), old_name,
                  new_name);
   }
+
+  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
+  if (it != attrs_.end()) {
+    auto &op_vars = boost::get<std::vector<std::string>>(it->second);
+    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
+  }
+
   need_update_ = true;
 }
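
`GetNullableAttr` returns a default-constructed `Attribute` (a boost::variant) when the name is absent, which is why the graph builder above wraps its `boost::get` in a try/catch for `boost::bad_get` rather than checking existence up front. A standalone analogue using C++17 `std::variant`; `MyAttr` and this `GetNullableAttr` are illustrative stand-ins, not the Paddle types:

#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-in for Paddle's Attribute variant.
using MyAttr = std::variant<std::monostate, int, std::vector<std::string>>;

// Analogue of OpDesc::GetNullableAttr: a missing name yields an empty
// variant instead of throwing, so the caller decides how to handle absence.
MyAttr GetNullableAttr(const std::map<std::string, MyAttr> &attrs,
                       const std::string &name) {
  auto it = attrs.find(name);
  return it != attrs.end() ? it->second : MyAttr{};
}

int main() {
  std::map<std::string, MyAttr> attrs{{"op_role", 1}};
  // Present: holds an int.
  std::cout << std::get<int>(GetNullableAttr(attrs, "op_role")) << "\n";
  // Absent: holds std::monostate; std::get<...> on it would throw
  // std::bad_variant_access, matching the boost::bad_get catch upstream.
  auto missing = GetNullableAttr(attrs, "op_role_var");
  std::cout << std::boolalpha
            << std::holds_alternative<std::monostate>(missing) << "\n";
  return 0;
}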

paddle/fluid/framework/op_desc.h

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,8 @@ class OpDesc {

   Attribute GetAttr(const std::string &name) const;

+  Attribute GetNullableAttr(const std::string &name) const;
+
   int GetBlockAttr(const std::string &name) const;

   void Rename(const std::string &old_name, const std::string &new_name);

paddle/fluid/framework/op_proto_maker.cc

Lines changed: 24 additions & 0 deletions
@@ -13,6 +13,7 @@ limitations under the License. */

 #include "paddle/fluid/framework/op_proto_maker.h"
 #include <string>
+#include <vector>

 namespace paddle {
 namespace framework {
@@ -55,5 +56,28 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
   }
 }

+void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
+                                        OpAttrChecker* attr_checker) {
+  proto_ = proto;
+  op_checker_ = attr_checker;
+  Make();
+
+  AddAttr<int>(OpRoleAttrName(), "The role of this operator")
+      .InEnum(
+          {static_cast<int>(OpRole::kForward),
+           static_cast<int>(OpRole::kBackward),
+           static_cast<int>(OpRole::kOptimize),
+           static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
+           static_cast<int>(OpRole::kLoss) |
+               static_cast<int>(OpRole::kBackward),
+           static_cast<int>(OpRole::kNotSpecified)})
+      .SetDefault(static_cast<int>(OpRole::kNotSpecified));
+  AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
+                                    "Optimized for variable")
+      .SetDefault({});
+
+  Validate();
+}
+
 }  // namespace framework
 }  // namespace paddle
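
Note that `.InEnum` whitelists exact attribute values rather than validating individual bits, so only the six listed combinations are legal; for example `kLoss | kOptimize` (0x0102) would be rejected by the checker. A quick sketch of that membership test; `IsLegalRole` is an illustrative restatement, not a Paddle function:

#include <initializer_list>
#include <iostream>

enum class OpRole {  // as declared in op_proto_maker.h
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kLoss = 0x0100,
  kNotSpecified = 0x1000,
};

// Illustrative restatement of the InEnum whitelist built in operator().
bool IsLegalRole(int v) {
  for (int legal :
       {static_cast<int>(OpRole::kForward),
        static_cast<int>(OpRole::kBackward),
        static_cast<int>(OpRole::kOptimize),
        static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
        static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kBackward),
        static_cast<int>(OpRole::kNotSpecified)}) {
    if (v == legal) return true;
  }
  return false;
}

int main() {
  std::cout << IsLegalRole(0x0101) << "\n";  // kLoss|kBackward -> 1 (legal)
  std::cout << IsLegalRole(0x0102) << "\n";  // kLoss|kOptimize -> 0 (rejected)
  return 0;
}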

paddle/fluid/framework/op_proto_maker.h

Lines changed: 17 additions & 6 deletions
@@ -20,21 +20,31 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+enum class OpRole {
+  kForward = 0x0000,
+  kBackward = 0x0001,
+  kOptimize = 0x0002,
+
+  kLoss = 0x0100,
+  // The default value of op's role. This should be only used for unittests and
+  // CreateOp inside a operator.
+  kNotSpecified = 0x1000,
+};
+
 // this class not only make proto but also init attribute checkers.
 class OpProtoAndCheckerMaker {
  public:
+  static const char *OpRoleAttrName() { return "op_role"; }
+  static const char *OpRoleVarAttrName() { return "op_role_var"; }
+
+  void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
+
   virtual void Make() = 0;

   virtual ~OpProtoAndCheckerMaker() {
     CHECK(validated_) << "should call Validate after build";
   }

-  void SetProto(proto::OpProto *proto) { proto_ = proto; }
-
-  void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
-
-  void Validate();
-
  protected:
   struct VariableBuilder {
     proto::OpProto::Var *var_;
@@ -76,6 +86,7 @@ class OpProtoAndCheckerMaker {

  private:
   void CheckNoDuplicatedInOutAttrs();
+  void Validate();

   proto::OpProto *proto_;
   OpAttrChecker *op_checker_;

paddle/fluid/framework/op_proto_maker_test.cc

Lines changed: 4 additions & 8 deletions
@@ -28,10 +28,8 @@ TEST(ProtoMaker, DuplicatedAttr) {
   paddle::framework::proto::OpProto op_proto;
   paddle::framework::OpAttrChecker op_checker;
   TestAttrProtoMaker proto_maker;
-  proto_maker.SetProto(&op_proto);
-  proto_maker.SetChecker(&op_checker);
-  proto_maker.Make();
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
+  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
+               paddle::platform::EnforceNotMet);
 }

 class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
@@ -46,8 +44,6 @@ TEST(ProtoMaker, DuplicatedInOut) {
   paddle::framework::proto::OpProto op_proto;
   paddle::framework::OpAttrChecker op_checker;
   TestAttrProtoMaker proto_maker;
-  proto_maker.SetProto(&op_proto);
-  proto_maker.SetChecker(&op_checker);
-  proto_maker.Make();
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
+  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
+               paddle::platform::EnforceNotMet);
 }

paddle/fluid/pybind/const_value.cc

Lines changed: 16 additions & 0 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/pybind/const_value.h"
+#include <paddle/fluid/framework/op_proto_maker.h>
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {
@@ -23,6 +24,21 @@ void BindConstValue(pybind11::module* m) {
   m->def("kTempVarName", [] { return framework::kTempVarName; });
   m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
+
+  auto op_proto_and_checker_maker =
+      m->def_submodule("op_proto_and_checker_maker");
+
+  pybind11::enum_<framework::OpRole>(op_proto_and_checker_maker, "OpRole")
+      .value("Forward", framework::OpRole::kForward)
+      .value("Backward", framework::OpRole::kBackward)
+      .value("Optimize", framework::OpRole::kOptimize)
+      .value("Loss", framework::OpRole::kLoss);
+
+  op_proto_and_checker_maker.def(
+      "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
+  op_proto_and_checker_maker.def(
+      "kOpRoleVarAttrName",
+      framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
 }

 }  // namespace pybind

python/paddle/fluid/backward.py

Lines changed: 55 additions & 8 deletions
@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
         op_desc.set_input(para, args)
     for para, args in outputs.iteritems():
         op_desc.set_output(para, args)
+
+    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
+
+    if op_role_attr_name not in attrs:
+        attrs[
+            op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
     for name, val in attrs.iteritems():
         if isinstance(val, framework.Block):
             op_desc.set_block_attr(name, val.desc)
@@ -141,15 +147,15 @@ def _addup_repetitive_outputs_(op_descs):
             else:
                 if len(renamed_vars[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
-                        str(var_rename_count[var_name])
+                               str(var_rename_count[var_name])
                     var_rename_count[var_name] += 1
                     # rename original var_name
                     renamed_vars[var_name][0] = new_name
                     _rename_arg_(op_descs, var_name, new_name, 0, idx)
                     _rename_arg_(pending_sum_ops, var_name, new_name)

                 new_name = var_name + "@RENAME@" + \
-                    str(var_rename_count[var_name])
+                           str(var_rename_count[var_name])
                 var_rename_count[var_name] += 1
                 op_desc.rename_output(var_name, new_name)
                 renamed_vars[var_name].append(new_name)
@@ -335,9 +341,12 @@ def _append_backward_ops_(block,
                                 no_grad_dict[block.idx])

     # append op_desc in grad_op_descs to target_block
+    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
+    backward = core.op_proto_and_checker_maker.OpRole.Backward
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
+        new_op_desc.set_attr(op_role_attr_name, backward)
         grad_to_var["__current_op_desc__"] = new_op_desc
         if callbacks is not None:
             assert (isinstance(callbacks, list))
@@ -439,6 +448,22 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
         (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)
+
+    if loss.op is None:
+        # the loss is from a cloned program. Find loss op manually.
+        for op in reversed(loss.block.ops):
+            assert isinstance(op, framework.Operator)
+            if len(op.output_arg_names) == 1 and op.output_arg_names[
+                    0] == loss.name:
+                loss.op = op
+                break
+        if loss.op is None:
+            raise ValueError("loss.op is None. Should not happend")
+
+    loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
+                     int(core.op_proto_and_checker_maker.OpRole.Forward) |
+                     int(core.op_proto_and_checker_maker.OpRole.Loss))
+
     if callbacks is not None:
         isinstance(callbacks, list)

@@ -456,12 +481,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
     current_block_idx = program.current_block_idx
     grad_to_var = dict()

-    op_desc = _create_op_desc_("fill_constant", {}, {
-        "Out": [_append_grad_suffix_(loss.name)]
-    }, {"shape": [1],
-        "value": 1.0,
-        "dtype": loss.dtype,
-        "force_cpu": False})
+    op_desc = _create_op_desc_(
+        "fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
+            "shape": [1],
+            "value": 1.0,
+            "dtype": loss.dtype,
+            "force_cpu": False,
+            core.op_proto_and_checker_maker.kOpRoleAttrName():
+            int(core.op_proto_and_checker_maker.OpRole.Backward) |
+            int(core.op_proto_and_checker_maker.OpRole.Loss),
+        })
     root_block.desc.append_op().copy_from(op_desc)

     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
@@ -505,6 +534,24 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
             params_and_grads.append((param_var, grad_var))
         else:
             params_and_grads.append((param_var, None))
+
+    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
+    for p, g in params_and_grads:
+        if g is None:
+            continue
+        for op in reversed(program.global_block().ops):
+            assert isinstance(op, framework.Operator)
+            if g.name in op.output_arg_names:
+                g.op = op
+                break
+
+        if g.op is None:
+            raise ValueError("Unexpected branch")
+        attr_val = [p.name, g.name]
+        if g.op.has_attr(op_role_var_attr_name):
+            attr_val.extend(g.op.attr(op_role_var_attr_name))
+        g.op.set_attr(op_role_var_attr_name, attr_val)
+
     return params_and_grads
