
Commit 3bd1d22

Authored by chengduo
Enhance fused_elementwise_activation_op (#12837)
* Enhance the function of fused_elementwise_activation_op
* enhance unit test
* Clean Code And Add Doc
* Add compound functors
* Fix doc and enhance unit test
* define Dx and Dy for d_binary_func
* add mul_scale
* add mul_scale
* add elementwise_mul
* code refine
* code refine
* add doc
* add AsIntermediate
1 parent a615ad4 commit 3bd1d22

File tree

7 files changed (+1897, -1257 lines)


paddle/fluid/operators/elementwise_op_function.h

Lines changed: 928 additions & 104 deletions
Large diffs are not rendered by default.

paddle/fluid/operators/fused_elemwise_activation_op.cc

Lines changed: 181 additions & 61 deletions
@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
-
 namespace paddle {
 namespace operators {
 
+/*
+ * Whether the compound function is Unary(Binary(X, Y)).
+ * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
+ * final out's.
+ */
+static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
+  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  static std::unordered_set<std::string> binary_fun = {
+      "elementwise_add", "elementwise_mul", "elementwise_add_grad",
+      "elementwise_mul_grad"};
+  return binary_fun.count(functor_list[1]) != 0;
+}
+
+/*
+ * Whether the Input(X) could be absent.
+ */
+static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
+  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
+  return binary_fun.count(functor_list[0]) != 0 ||
+         binary_fun.count(functor_list[1]) != 0;
+}
+
+/*
+ * Whether the compound function is supported.
+ * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
+ * final out's.
+ */
+static bool IsSupportedCompound(const std::vector<std::string> &functors) {
+  static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
+  static std::unordered_set<std::string> binary_fun = {"elementwise_add",
+                                                       "elementwise_mul"};
+
+  std::string unary_fun_str;
+  if (binary_fun.count(functors[0])) {
+    unary_fun_str = functors[1];
+  } else if (binary_fun.count(functors[1])) {
+    unary_fun_str = functors[0];
+  } else {
+    PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
+                 functors[1]);
+  }
+  PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
+                    "%s is not included in fused_list.", unary_fun_str);
+  return true;
+}
+
 class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
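
The helpers above encode the functor_list convention: a two-element list describes Unary(Binary(X, Y)) exactly when its second entry names a binary functor, and Binary(X, Unary(Y)) otherwise. A minimal standalone sketch of that classification (plain STL, hypothetical name IsUnaryCompoundSketch; not part of this commit):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

static bool IsUnaryCompoundSketch(const std::vector<std::string> &functor_list) {
  static const std::unordered_set<std::string> binary_fun = {"elementwise_add",
                                                             "elementwise_mul"};
  // functor_list[1] being a binary op means the unary functor wraps the binary
  // result: Out = Unary(Binary(X, Y)).
  return binary_fun.count(functor_list[1]) != 0;
}

int main() {
  // {"scale", "elementwise_add"} -> Out = scale(X + Y)   (Unary(Binary(X, Y)))
  std::cout << IsUnaryCompoundSketch({"scale", "elementwise_add"}) << "\n";  // 1
  // {"elementwise_add", "scale"} -> Out = X + scale(Y)   (Binary(X, Unary(Y)))
  std::cout << IsUnaryCompoundSketch({"elementwise_add", "scale"}) << "\n";  // 0
  return 0;
}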
@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
 
     auto x_dim = ctx->GetInputDim("X");
     auto y_dim = ctx->GetInputDim("Y");
-    PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
-                      "Rank of first input must >= rank of second input.");
 
-    ctx->SetOutputDim("Out", x_dim);
-    ctx->ShareLoD("X", /*->*/ "Out");
+    // Whether the shape of Y is a continuous subsequence of X,
+    // For more information please refer to the op's introduction.
+    bool bcast_y = x_dim.size() >= y_dim.size();
+    if (x_dim.size() == y_dim.size()) {
+      for (int i = 0; i < x_dim.size(); ++i) {
+        if (x_dim[i] < y_dim[i]) {
+          bcast_y = false;
+          break;
+        }
+      }
+    }
+
+    auto &out_dim = bcast_y ? x_dim : y_dim;
+    std::string out_lod = bcast_y ? "X" : "Y";
+
+    if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
+      PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
+                     "Output(IntermediateOut) of FusedElemwiseActivationOp "
+                     "should not be null.");
+
+      if (IsUnaryCompound(
+              ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
+        // for Unary(Binary(X, Y)), the shape and lod of out and
+        // intermediate_out are the same.
+        ctx->SetOutputDim("IntermediateOut", out_dim);
+        // set the lod of intermediate_out
+        ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut");
+      } else {
+        // for Binary(X, Unary(Y)), the shape and lod of Y and
+        // intermediate_out are the same.
+        ctx->SetOutputDim("IntermediateOut", y_dim);
+        // set the lod of intermediate_out
+        ctx->ShareLoD("Y", /*->*/ "IntermediateOut");
+      }
+    }
+    ctx->SetOutputDim("Out", out_dim);
+    ctx->ShareLoD(out_lod, /*->*/ "Out");
   }
 
  protected:
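
The InferShape change above picks Out's shape from whichever input is the "larger" one, i.e. the one the other input broadcasts to. A minimal sketch of that rule with plain STL types (hypothetical helper ChooseOutDim, not Paddle's DDim):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ChooseOutDim(const std::vector<int64_t> &x_dim,
                                  const std::vector<int64_t> &y_dim) {
  bool bcast_y = x_dim.size() >= y_dim.size();
  if (x_dim.size() == y_dim.size()) {
    for (size_t i = 0; i < x_dim.size(); ++i) {
      if (x_dim[i] < y_dim[i]) {
        bcast_y = false;  // Y is the larger tensor, so Out follows Y.
        break;
      }
    }
  }
  return bcast_y ? x_dim : y_dim;
}

int main() {
  std::vector<int64_t> x = {2, 3, 4, 5};
  std::vector<int64_t> y = {4, 5};
  // Y = (4, 5) is a continuous (trailing) subsequence of X = (2, 3, 4, 5):
  // Out follows X's shape.
  assert(ChooseOutDim(x, y) == x);
  // Swap the roles: now the first input is the smaller one, so Out follows
  // the second input's shape.
  assert(ChooseOutDim(y, x) == x);
  return 0;
}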
@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
 class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("X", "(vector<Tensor>)");
-    AddInput("Y", "(vector<Tensor>)");
-    AddOutput("Out", "vector<Tensor>");
+    AddInput(
+        "X",
+        "(Tensor) The input tensor of fused_elemwise_activation operator.");
+    AddInput(
+        "Y",
+        "(Tensor) The input tensor of fused_elemwise_activation operator.");
+    AddOutput("Out",
+              "vector<Tensor> The output tensor of fused_elemwise_activation "
+              "operator.");
+    AddOutput("IntermediateOut",
+              "Tensor The IntermediateOut tensor of fused_elemwise_activation "
+              "operator.")
+        .AsIntermediate();
     AddAttr<int>("axis",
                  "axis is used by elementwise_op, the default value is -1.")
         .SetDefault(-1);
     AddAttr<float>("scale",
                    "scale is used by scale_op, the default value is 0.0.")
         .SetDefault(0.0);
-    AddAttr<bool>("recomputation",
-                  "Whether to recompute the Out."
-                  "fused_elemwise_activation_grad has two methods to get the "
-                  "dx and dy, one "
-                  "is to use the 'Out', and the other is not to use it. "
-                  "The former method will save the time of recomputing the "
-                  "'Out', but it must occupy the memory to store the 'out'. "
-                  "While, the later method can avoid occupying the memory, "
-                  "but it must recompute the 'Out'. The default value is true.")
+    AddAttr<bool>(
+        "recomputation",
+        "Whether to recompute the Out. "
+        "The computation of fused_elemwise_activation_grad has two methods to "
+        "get the dx and dy, one is to use the 'Out', and the other is not. "
+        "The former method saves the time of recomputing the 'Out', but it "
+        "must occupy the memory to store the 'Out'. The latter method avoids "
+        "occupying the memory, but it must recompute the 'Out'. "
+        "It is useful for Unary(Binary(X, Y)). The default value is true.")
         .SetDefault(true);
+    AddAttr<bool>("keep_intermediate_value",
+                  "Whether to save the intermediate_out.")
+        .SetDefault(false);
     AddAttr<std::vector<std::string>>("functor_list",
                                       "The functors that should be fused.")
         .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
-          PADDLE_ENFORCE(ValidCheck(functor_list));
+          PADDLE_ENFORCE(IsSupportedCompound(functor_list));
         });
 
     AddComment(R"DOC(
@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op):
 Z = Binary(X, Unary(Y))
 Z = Unary(Binary(X, Y))
 
-The attributions of activation_op can be get from fused_elemwise_activation_op's
-attributions. functor_list records the functors to be fused, for example
-"scale,elementwise_add".
+There are two cases for this operator:
 
-)DOC");
-  }
+1. The shape of $Y$ and $X$ is the same.
+2. The shape of $Y$ is a continuous subsequence of $X$ or the shape of $X$ is a continuous subsequence of $Y$.
 
- private:
-  bool ValidCheck(const std::vector<std::string> &functors) {
-    std::unordered_set<std::string> unary_fun = {"scale", "relu"};
-    std::unordered_set<std::string> binary_fun = {"elementwise_add"};
+For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$):
 
-    std::string unary_fun_str;
-    if (binary_fun.count(functors[0])) {
-      unary_fun_str = functors[1];
-    } else if (binary_fun.count(functors[1])) {
-      unary_fun_str = functors[0];
-    } else {
-      PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
-                   functors[1]);
-    }
-    PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
-                      "%s is not included in fused_list.", unary_fun_str);
-    return true;
+1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
+   for broadcasting $Y$ onto $X$.
+2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
+3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
+   subsequence, such as shape(Y) = (2, 1) => (2).
+
+For example:
+
+  .. code-block:: python
+
+    shape(X) = (2, 3, 4, 5), shape(Y) = (,)
+    shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
+    shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1 (default) or axis=2
+    shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
+    shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
+    shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
+
+The inputs $X$ and $Y$ can carry different LoD information.
+But the output only shares the LoD information with the input whose shape is the same as Out's.
+The attributes of activation_op can be obtained from fused_elemwise_activation_op's attributes.
+The functor_list records the functions to be fused, for example
+["scale", "elementwise_add"].
+
+)DOC");
   }
 };
 
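As a worked illustration of the doc above: with functor_list ["scale", "elementwise_add"] the compound is Unary(Binary(X, Y)), i.e. Out = scale * (X + Y), with Y broadcast over X when Y's shape is a trailing continuous subsequence of X's. A hedged sketch (hypothetical FusedAddScale; broadcasting modeled by modulo indexing, which assumes Y matches X's trailing dimensions):

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> FusedAddScale(const std::vector<float> &x,
                                 const std::vector<float> &y, float scale) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = scale * (x[i] + y[i % y.size()]);  // Unary(Binary(X, Y))
  }
  return out;
}

int main() {
  // shape(X) = (2, 3), shape(Y) = (3,), so axis defaults to rank(X) - rank(Y) = 1.
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> y = {10, 20, 30};
  for (float v : FusedAddScale(x, y, 0.5f)) std::cout << v << " ";
  // prints: 5.5 11 16.5 7 12.5 18
  return 0;
}
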
@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker
       op_desc_ptr->SetInput(framework::GradVarName(output_param),
                             this->OutputGrad(output_param));
     }
+
     op_desc_ptr->SetAttrMap(this->Attrs());
 
     std::vector<std::string> functor_names =
@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-
-    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                      "Rank of first input must >= rank of second input.");
+                   "Input(Out@Grad) should not be null");
+    if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
+      PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
+                     "Input(IntermediateOut) should not be null");
+    } else {
+      PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
+    }
 
+    auto functor_list =
+        ctx->Attrs().Get<std::vector<std::string>>("functor_list");
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
+
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
+      if (ctx->HasInputs("X")) {
+        ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+        ctx->ShareLoD("X", x_grad_name);
+      } else {
+        // Note: if "X" is absent, the shape of Y should be a continuous
+        // subsequence of X; otherwise we could not infer the shape of dx.
+
+        // Currently, "X" may be absent only when Binary is elementwise_add or
+        // elementwise_sub.
+        PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
+                       "Only when BinaryFunctor is elementwise_add, the 'X' "
+                       "could be absent.");
+
+        // For Unary(Binary(X, Y)), IntermediateOut should not be empty.
+        if (IsUnaryCompound(functor_list)) {
+          PADDLE_ENFORCE(
+              ctx->HasInputs("IntermediateOut"),
+              "If the compound_functor is Unary(Binary(X, Y)) and Binary "
+              "is elementwise_add, the intermediate_out must not be absent.");
+        }
+
+        ctx->SetOutputDim(x_grad_name,
+                          ctx->GetInputDim(framework::GradVarName("Out")));
+        ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
+      }
     }
     if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
+      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+      ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
+      ctx->ShareLoD("Y", y_grad_name);
     }
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    auto input_data_type_index = ctx.Input<framework::Tensor>("X")->type();
-    PADDLE_ENFORCE_EQ(input_data_type_index,
-                      ctx.Input<framework::Tensor>("Y")->type(),
-                      "The element's type of input should be the same.");
-    PADDLE_ENFORCE_EQ(
-        input_data_type_index,
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
-        "The element's type of input should be the same.");
-
+    // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+    auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
     auto input_data_type = framework::ToDataType(input_data_type_index);
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
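
A minimal sketch of the dX shape rule introduced in the grad InferShape above (hypothetical helper InferDxDim, plain STL types): dX reuses X's shape when X is present; otherwise it falls back to the shape of Out@Grad, which the code only allows when the binary functor is elementwise_add (see InputXCanBeAbsent):

#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> InferDxDim(bool has_x, const std::vector<int64_t> &x_dim,
                                const std::vector<int64_t> &out_grad_dim,
                                bool x_can_be_absent) {
  if (has_x) return x_dim;  // dX shares X's shape (and LoD).
  if (!x_can_be_absent) {
    throw std::runtime_error("X may be absent only for elementwise_add.");
  }
  return out_grad_dim;  // dX falls back to the shape of Out@Grad.
}

int main() {
  std::vector<int64_t> x_dim = {2, 3, 4, 5};
  std::vector<int64_t> dout_dim = {2, 3, 4, 5};
  // With X available, dX simply reuses X's dims.
  auto dx = InferDxDim(/*has_x=*/true, x_dim, dout_dim, /*x_can_be_absent=*/false);
  // Without X (e.g. elementwise_add in the functor list), dX follows Out@Grad.
  auto dx_no_x = InferDxDim(/*has_x=*/false, {}, dout_dim, /*x_can_be_absent=*/true);
  return (dx == dout_dim && dx_no_x == dout_dim) ? 0 : 1;
}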
