@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

+ #include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string>
#include <vector>

- #include "paddle/fluid/operators/fused_elemwise_activation_op.h"
-
namespace paddle {
namespace operators {

+ /*
+  * Whether the compound function is Unary(Binary(X, Y)).
+  * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
+  * final out's.
+  */
+ static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
+   PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+   static std::unordered_set<std::string> binary_fun = {
+       "elementwise_add", "elementwise_mul", "elementwise_add_grad",
+       "elementwise_mul_grad"};
+   return binary_fun.count(functor_list[1]) != 0;
+ }
+
+ /*
+  * Whether the Input(X) could be absent.
+  */
+ static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
+   PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+   static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
+   return binary_fun.count(functor_list[0]) != 0 ||
+          binary_fun.count(functor_list[1]) != 0;
+ }
+
+ /*
+  * Whether the compound function is supported.
+  * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
+  * final out's.
+  */
+ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
+   static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
+   static std::unordered_set<std::string> binary_fun = {"elementwise_add",
+                                                        "elementwise_mul"};
+
+   std::string unary_fun_str;
+   if (binary_fun.count(functors[0])) {
+     unary_fun_str = functors[1];
+   } else if (binary_fun.count(functors[1])) {
+     unary_fun_str = functors[0];
+   } else {
+     PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
+                  functors[1]);
+   }
+   PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
+                     "%s is not included in fused_list.", unary_fun_str);
+   return true;
+ }
+
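An aside on the helpers above: a functor_list always has exactly two entries, and the position of the binary functor decides the compound form, so a binary op in slot 1 means Unary(Binary(X, Y)). A minimal, hypothetical usage sketch (not part of the patch):

    std::vector<std::string> functors = {"scale", "elementwise_add"};
    bool is_unary_compound = IsUnaryCompound(functors);   // true: scale(elementwise_add(X, Y))
    bool is_supported = IsSupportedCompound(functors);    // true: one unary paired with one binary functor
    bool x_optional = InputXCanBeAbsent(functors);        // false: only elementwise_add_grad allows a missing X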

class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputDim("Y");
-   PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
-                     "Rank of first input must >= rank of second input.");

-   ctx->SetOutputDim("Out", x_dim);
-   ctx->ShareLoD("X", /*->*/ "Out");
+   // Whether the shape of Y is a continuous subsequence of X.
+   // For more information please refer to the op's introduction.
+   bool bcast_y = x_dim.size() >= y_dim.size();
+   if (x_dim.size() == y_dim.size()) {
+     for (int i = 0; i < x_dim.size(); ++i) {
+       if (x_dim[i] < y_dim[i]) {
+         bcast_y = false;
+         break;
+       }
+     }
+   }
+
+   auto &out_dim = bcast_y ? x_dim : y_dim;
+   std::string out_lod = bcast_y ? "X" : "Y";
+
+   if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
+     PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
+                    "Output(IntermediateOut) of FusedElemwiseActivationOp "
+                    "should not be null.");
+
+     if (IsUnaryCompound(
+             ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
+       // For Unary(Binary(X, Y)), the shape and lod of out and
+       // intermediate_out are the same.
+       ctx->SetOutputDim("IntermediateOut", out_dim);
+       // Set the lod of intermediate_out.
+       ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut");
+     } else {
+       // For Binary(X, Unary(Y)), the shape and lod of Y and
+       // intermediate_out are the same.
+       ctx->SetOutputDim("IntermediateOut", y_dim);
+       // Set the lod of intermediate_out.
+       ctx->ShareLoD("Y", /*->*/ "IntermediateOut");
+     }
+   }
+   ctx->SetOutputDim("Out", out_dim);
+   ctx->ShareLoD(out_lod, /*->*/ "Out");
  }
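A short aside on the shape inference above (a sketch, not part of the patch): Out follows X when Y broadcasts onto X and follows Y otherwise, and the same rule picks which input's LoD is shared. The decision can be restated with plain vectors instead of framework::DDim:

    #include <cstdint>
    #include <vector>

    // Standalone restatement of the bcast_y check used by InferShape above.
    static bool BroadcastYOntoX(const std::vector<int64_t> &x_dim,
                                const std::vector<int64_t> &y_dim) {
      bool bcast_y = x_dim.size() >= y_dim.size();
      if (x_dim.size() == y_dim.size()) {
        // Equal ranks: any dimension where X is smaller means Y is the larger shape.
        for (size_t i = 0; i < x_dim.size(); ++i) {
          if (x_dim[i] < y_dim[i]) {
            bcast_y = false;
            break;
          }
        }
      }
      return bcast_y;
    }

    // BroadcastYOntoX({2, 3, 4, 5}, {4, 5}) -> true  (Out takes X's shape and LoD)
    // BroadcastYOntoX({2, 3}, {2, 5})       -> false (Out takes Y's shape and LoD)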
protected:
@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
-   AddInput("X", "(vector<Tensor>)");
-   AddInput("Y", "(vector<Tensor>)");
-   AddOutput("Out", "vector<Tensor>");
+   AddInput(
+       "X",
+       "(Tensor) The input tensor of fused_elemwise_activation operator.");
+   AddInput(
+       "Y",
+       "(Tensor) The input tensor of fused_elemwise_activation operator.");
+   AddOutput("Out",
+             "vector<Tensor> The output tensor of fused_elemwise_activation "
+             "operator.");
+   AddOutput("IntermediateOut",
+             "Tensor The IntermediateOut tensor of fused_elemwise_activation "
+             "operator.")
+       .AsIntermediate();
    AddAttr<int>("axis",
                 "axis is used by elementwise_op, the default value is -1.")
        .SetDefault(-1);
    AddAttr<float>("scale",
                   "scale is used by scale_op, the default value is 0.0.")
        .SetDefault(0.0);
-   AddAttr<bool>("recomputation",
-                 "Whether to recompute the Out. "
-                 "fused_elemwise_activation_grad has two methods to get the "
-                 "dx and dy, one "
-                 "is to use the 'Out', and the other is not to use it. "
-                 "The former method will save the time of recomputing the "
-                 "'Out', but it must occupy the memory to store the 'out'. "
-                 "While, the later method can avoid occupying the memory, "
-                 "but it must recompute the 'Out'. The default value is true.")
+   AddAttr<bool>(
+       "recomputation",
+       "Whether to recompute the Out. "
+       "The computation of fused_elemwise_activation_grad has two methods to "
+       "get the dx and dy: one is to use the 'Out', and the other is not. "
+       "The former method saves the time of recomputing the 'Out', but it "
+       "must occupy memory to store the 'Out'. The latter method avoids "
+       "occupying that memory, but it must recompute the 'Out'. "
+       "It is useful for Unary(Binary(X, Y)). The default value is true.")
        .SetDefault(true);
+   AddAttr<bool>("keep_intermediate_value",
+                 "Whether to save the intermediate_out.")
+       .SetDefault(false);
    AddAttr<std::vector<std::string>>("functor_list",
                                      "The functors that should be fused.")
        .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
-         PADDLE_ENFORCE(ValidCheck(functor_list));
+         PADDLE_ENFORCE(IsSupportedCompound(functor_list));
        });

    AddComment(R"DOC(
@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op):
Z = Binary(X, Unary(Y))
Z = Unary(Binary(X, Y))

- The attributions of activation_op can be get from fused_elemwise_activation_op's
- attributions. functor_list records the functors to be fused, for example
- "scale,elementwise_add".
+ There are two cases for this operator:

- )DOC");
-   }
+ 1. The shapes of $X$ and $Y$ are the same.
+ 2. The shape of $Y$ is a continuous subsequence of $X$, or the shape of $X$ is a continuous subsequence of $Y$.

-  private:
-   bool ValidCheck(const std::vector<std::string> &functors) {
-     std::unordered_set<std::string> unary_fun = {"scale", "relu"};
-     std::unordered_set<std::string> binary_fun = {"elementwise_add"};
+ For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$):

-     std::string unary_fun_str;
-     if (binary_fun.count(functors[0])) {
-       unary_fun_str = functors[1];
-     } else if (binary_fun.count(functors[1])) {
-       unary_fun_str = functors[0];
-     } else {
-       PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
-                    functors[1]);
-     }
-     PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
-                       "%s is not included in fused_list.", unary_fun_str);
-     return true;
+ 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
+    for broadcasting $Y$ onto $X$.
+ 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
+ 3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
+    subsequence, such as shape(Y) = (2, 1) => (2).
+
+ For example:
+
+   .. code-block:: python
+
+     shape(X) = (2, 3, 4, 5), shape(Y) = (,)
+     shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
+     shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
+     shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
+     shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
+     shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
+
+
+ The inputs $X$ and $Y$ can carry different LoD information,
+ but the output only shares the LoD information with the input whose shape is the same as Out's.
+ The attributes of activation_op can be obtained from fused_elemwise_activation_op's attributes.
+ The functor_list records the functors to be fused, for example
+ ["scale", "elementwise_add"].
+
+ )DOC");
  }
};
@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker
      op_desc_ptr->SetInput(framework::GradVarName(output_param),
                            this->OutputGrad(output_param));
    }
+
    op_desc_ptr->SetAttrMap(this->Attrs());

    std::vector<std::string> functor_names =
@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
-   PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-   PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                  "Input(Out@GRAD) should not be null");
-
-   auto x_dims = ctx->GetInputDim("X");
-   auto y_dims = ctx->GetInputDim("Y");
-   auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-
-   PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                     "Rank of first input must >= rank of second input.");
+                  "Input(Out@Grad) should not be null");
+   if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
+     PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
+                    "Input(IntermediateOut) should not be null");
+   } else {
+     PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
+   }

+   auto functor_list =
+       ctx->Attrs().Get<std::vector<std::string>>("functor_list");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
+
    if (ctx->HasOutput(x_grad_name)) {
-     ctx->SetOutputDim(x_grad_name, x_dims);
+     if (ctx->HasInputs("X")) {
+       ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+       ctx->ShareLoD("X", x_grad_name);
+     } else {
+       // Note: If "X" is absent, the shape of Y should be a continuous
+       // subsequence of X; otherwise we could not infer the shape of dx.
+
+       // Currently, the "X" could be absent only when Binary is
+       // elementwise_add or elementwise_sub.
+       PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
+                      "Only when BinaryFunctor is elementwise_add, the 'X' "
+                      "could be absent.");
+
+       // For Unary(Binary(X, Y)), IntermediateOut should not be empty.
+       if (IsUnaryCompound(functor_list)) {
+         PADDLE_ENFORCE(
+             ctx->HasInputs("IntermediateOut"),
+             "If the compound_functor is Unary(Binary(X, Y)) and Binary "
+             "is elementwise_add, the intermediate_out must not be absent.");
+       }
+
+       ctx->SetOutputDim(x_grad_name,
+                         ctx->GetInputDim(framework::GradVarName("Out")));
+       ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
+     }
    }
    if (ctx->HasOutput(y_grad_name)) {
-     ctx->SetOutputDim(y_grad_name, y_dims);
+     PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+     ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
+     ctx->ShareLoD("Y", y_grad_name);
    }
  }
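An illustrative trace of the dX branch above, assuming a grad-time functor list derived from the forward pair {"scale", "elementwise_add"} (a sketch under assumed names, not part of the patch):

    // Hypothetical grad functor list; names follow the *_grad spelling used by the helpers above.
    std::vector<std::string> grad_functors = {"scale_grad", "elementwise_add_grad"};
    bool x_optional = InputXCanBeAbsent(grad_functors);    // true: slot 1 holds elementwise_add_grad
    bool unary_compound = IsUnaryCompound(grad_functors);  // true, so IntermediateOut must be fed
    // With X absent, dX is given the shape and LoD of Out@GRAD instead of X.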
protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-   auto input_data_type_index = ctx.Input<framework::Tensor>("X")->type();
-   PADDLE_ENFORCE_EQ(input_data_type_index,
-                     ctx.Input<framework::Tensor>("Y")->type(),
-                     "The element's type of input should be the same.");
-   PADDLE_ENFORCE_EQ(
-       input_data_type_index,
-       ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
-       "The element's type of input should be the same.");
-
+   // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+   // The kernel data type is taken from Y rather than X, because X may be
+   // absent in the grad op (see InputXCanBeAbsent).
+   auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
    auto input_data_type = framework::ToDataType(input_data_type_index);
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }