Commit 595a2c8

explicit gradient of elementwise_add/elementwise_sub (#11970)
* "add gradient register" * "make some enhance" * "better format" * "fix typo" * "fix reuse" * "fix get expected kernel" * "change the mkldnn code" * "fix mkldnn" * "fix mkldnn failed test" * "add comment"
1 parent f37f875 commit 595a2c8

13 files changed: +405 −114 lines
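Background on why an "explicit" gradient is possible here (standard calculus, not spelled out in the commit itself): for $Out = X + Y$ the backward pass needs only the upstream gradient, since

$$\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Out}, \qquad \frac{\partial L}{\partial Y} = \sum_{\text{broadcast axes}} \frac{\partial L}{\partial Out},$$

and for $Out = X - Y$ the second term picks up a minus sign. Neither $X$, $Y$, nor $Out$ is needed at backward time, which is what lets the registrations below drop them from the gradient op's inputs.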

paddle/fluid/framework/op_proto_maker.cc

Lines changed: 34 additions & 0 deletions
```diff
@@ -40,6 +40,40 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
   return OpProtoAndCheckerMaker::VariableBuilder{output};
 }
 
+void OpProtoAndCheckerMaker::Reuse(const std::string& name,
+                                   const std::string& reused_name) {
+  bool found = false;
+  proto::OpProto::Var* var;
+
+  for (auto& var : proto_->inputs()) {
+    if (var.name() == reused_name) {
+      found = true;
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+
+  found = false;
+  for (int i = 0; i < proto_->outputs().size(); ++i) {
+    var = proto_->mutable_outputs()->Mutable(i);
+    if (var->name() == name) {
+      PADDLE_ENFORCE(!var->has_reuse(),
+                     "Output(%s) has been set reused var of %s", name,
+                     var->reuse());
+      found = true;
+      var->set_reuse(reused_name);
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+}
+
 void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
   std::unordered_set<std::string> names;
   auto checker = [&](const std::string& name) {
```
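For orientation, a minimal sketch of how an operator maker could call the new hook; the op and variable names here (`ReluInplaceOpMaker`, `XOut`) are illustrative and not part of this commit. `Reuse(output_name, input_name)` records on the output's proto `Var` that it may share the input's memory, with the existence and no-double-reuse checks enforced by the function above.

```cpp
// Hypothetical maker sketch, not from this commit: declares that output
// "XOut" may reuse the buffer of input "X".
class ReluInplaceOpMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "input of the hypothetical in-place relu op");
    AddOutput("XOut", "output written over the storage of X");
    Reuse("XOut", "X");  // mark XOut as reusing X
  }
};
```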

paddle/fluid/framework/op_proto_maker.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -78,6 +78,8 @@ class OpProtoAndCheckerMaker {
   VariableBuilder AddOutput(const std::string &name,
                             const std::string &comment);
 
+  void Reuse(const std::string &name, const std::string &reused_name);
+
   template <typename T>
   TypedAttrChecker<T> &AddAttr(const std::string &name,
                                const std::string &comment,
```

paddle/fluid/framework/op_proto_maker_test.cc

Lines changed: 102 additions & 5 deletions
```diff
@@ -49,6 +49,15 @@ TEST(ProtoMaker, DuplicatedInOut) {
 }
 
 class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddOutput("XOut", "output of test op").Reuse("X");
+  }
+};
+
+class TestInplaceProtoMaker2
+    : public paddle::framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "input of test op");
@@ -58,12 +67,100 @@ class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 };
 
 TEST(ProtoMaker, InplaceOutput) {
-  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::proto::OpProto op_proto, op_proto2;
   paddle::framework::OpAttrChecker op_checker;
   TestInplaceProtoMaker proto_maker;
-  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
+  TestInplaceProtoMaker2 proto_maker2;
+
+  proto_maker(&op_proto, &op_checker);
+
+  ASSERT_THROW(proto_maker2(&op_proto2, &op_checker),
                paddle::platform::EnforceNotMet);
-  // proto_maker(&op_proto, &op_checker);
-  // proto_maker.Make();
-  // ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
+
+// normal reuse
+class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddInput("Y", "input of test op");
+    AddOutput("Out", "output of test op");
+    AddOutput("XOut", "output of test op");
+    // avoid destructor exception.
+    // Validate();
+    TestReuse();
+  }
+
+  virtual void TestReuse() {}
+};
+
+// test duplicate reuse error
+class TestReuseProtoMaker2 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "X");
+    Reuse("Out", "Y");
+  }
+};
+
+// NotExists Input
+class TestReuseProtoMaker3 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "NotExists");
+    Reuse("XOut", "X");
+  }
+};
+
+// NotExists Output
+class TestReuseProtoMaker4 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() { Reuse("NotExists", "X"); }
+};
+
+TEST(ProtoMaker, Reuse) {
+  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker proto_maker;
+  proto_maker(&op_proto, &op_checker);
+}
+
+// NOTE(dzhwinter):
+// There is a fatal CHECK in the base class destructor, which calls abort()
+// instead of throwing an exception. If we throw an exception in Make(), we
+// will trigger the CHECK and terminate the tests.
+//
+// I had tried to replace the default CHECK with an exception; however, it's
+// still not supported by glog. The details:
+// https://github.com/google/glog/issues/249
+// https://github.com/facebookresearch/TensorComprehensions/issues/351
+/*
+TEST(ProtoMaker, ReuseWithException) {
+  paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker2 proto_maker2;
+  TestReuseProtoMaker3 proto_maker3;
+  TestReuseProtoMaker4 proto_maker4;
+  EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
+               paddle::platform::EnforceNotMet);
+
+  EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
+               paddle::platform::EnforceNotMet);
+
+  EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
+               paddle::platform::EnforceNotMet);
+}
+
+void FailureFunction() {
+  throw std::runtime_error("Check failed in destructor.");
+  // return 0;
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  google::InstallFailureFunction(&FailureFunction);
+  return RUN_ALL_TESTS();
+}
+*/
```
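A note on the disabled block above: `PADDLE_ENFORCE` raises `paddle::platform::EnforceNotMet`, which `EXPECT_THROW` can catch, but a failing glog `CHECK` aborts the process outright. The commented-out `main` sketches one possible workaround, installing a `google::InstallFailureFunction` handler that throws instead of aborting; it presumably stays disabled because throwing from glog's failure handler is not officially supported, per the two linked issues.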

paddle/fluid/operators/elementwise_add_mkldnn_op.cc

Lines changed: 25 additions & 22 deletions
```diff
@@ -47,12 +47,12 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");
 
     auto x_dims = x->dims();
-    auto y_dims = y->dims();
+    auto y_dims_untrimed = y->dims();
     auto z_dims = z->dims();
 
     // Execute default elementwise_add operator when
     // broadcast operations need to performed.
-    if (x_dims != y_dims) {
+    if (x_dims != y_dims_untrimed) {
       auto sum_func = [](T a, T b) -> T { return a + b; };
 
       TransformFunctor<decltype(sum_func), T,
@@ -62,11 +62,11 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
           ctx.template device_context<paddle::platform::CPUDeviceContext>(),
           sum_func);
 
-      axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+      axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
       PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                      "Axis should be in range [0, x_dims)");
 
-      trim_trailing_singular_dims(&y_dims);
+      auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
       axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
       int pre, n, post;
@@ -88,7 +88,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
                      "Wrong layout/format set for Y tensor");
 
       std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
-      std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
+      std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed);
       std::vector<int> dst_tz = framework::vectorize2int(z_dims);
 
       std::vector<memory::primitive_desc> srcs_pd;
@@ -142,36 +142,39 @@ class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
+    // skip out, x, y;
+    // dout's length is larger than or equal to that of dx and dy.
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
     auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
       in->set_layout(DataLayout::kMKLDNN);
      in->set_format(out->format());
     };
 
-    if (x->dims() == y->dims()) {
-      auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
-      if (dx) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dx->mutable_data<T>(ctx.GetPlace()));
-        set_mkldnn_format(dx, dout);
-      }
-
-      if (dy) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dy->mutable_data<T>(ctx.GetPlace()));
-        set_mkldnn_format(dy, dout);
+    if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
+      if (dx->dims() == dy->dims()) {
+        auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
+        if (dx) {
+          blas.VCOPY(dout->numel(), dout->data<T>(),
+                     dx->mutable_data<T>(ctx.GetPlace()));
+          set_mkldnn_format(dx, dout);
+        }
+
+        if (dy) {
+          blas.VCOPY(dout->numel(), dout->data<T>(),
+                     dy->mutable_data<T>(ctx.GetPlace()));
+          set_mkldnn_format(dy, dout);
+        }
       }
     } else {
       // Execute default kernel when broadcast is needed
-      ElemwiseGradCompute<paddle::platform::CPUDeviceContext, T,
-                          IdentityGrad<T>, IdentityGrad<T>>(
+      ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
+                                  IdentityGrad<T>, IdentityGrad<T>>(
           ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
           IdentityGrad<T>());
     }
```
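Both code paths above lean on the `IdentityGrad` functor. As a point of reference, a minimal sketch of the shape it must have, inferred from its usage here rather than quoted from the commit (`HOSTDEVICE` is Paddle's host/device qualifier macro):

```cpp
// Identity gradient for addition: since Out = X + Y, both dX and dY are
// simply dOut, so x, y, and out are ignored -- which is exactly why the
// grad kernel can alias all three of them to dout.
template <typename T>
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
```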

paddle/fluid/operators/elementwise_add_op.cc

Lines changed: 3 additions & 1 deletion
```diff
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
+REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
+REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
+                              "X");
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
```
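A hedged sketch of the kind of grad-op maker `REGISTER_ELEMWISE_GRAD_MAKER` is expected to generate; this illustrates the pattern, not the literal macro expansion. The point is that only `Out@GRAD` is wired into the gradient op, which is what "explicit gradient" means here.

```cpp
// Illustrative sketch only: a grad-op maker that forwards just the upstream
// gradient, dropping X, Y, and Out from the grad op's inputs.
class ElementwiseAddGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* grad_op = new framework::OpDesc();
    grad_op->SetType("elementwise_add_grad");
    // Only the upstream gradient is consumed; no forward tensors.
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
    grad_op->SetAttrMap(Attrs());
    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
```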

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 9 additions & 7 deletions
```diff
@@ -95,9 +95,10 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                                   framework::Tensor* dy) {
   int axis = ctx.Attr<int>("axis");
 
-  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-      IdentityGrad<T>());
+  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
+                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
+                                               dx, dy, IdentityGrad<T>(),
+                                               IdentityGrad<T>());
 }
 
 template <typename DeviceContext, typename T>
@@ -140,14 +141,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    // skip out, x, y
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
-    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
+    if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
+        dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
       default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
```
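Note the pattern shared with the MKLDNN kernel above: because `x`, `y`, and `out` are now aliases of `dout`, the old `x->dims() == y->dims()` guard would always hold, so the fast-path test moves to the output gradients (`dx->dims() == dy->dims()`), and `ElemwiseExplicitGradCompute` is assumed never to read the forward tensors it is passed.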

paddle/fluid/operators/elementwise_div_op.cc

Lines changed: 2 additions & 0 deletions
```diff
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_div_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 namespace ops = paddle::operators;
+
 REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
```
