@@ -21,8 +21,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
-DEFINE_bool(gru_use_seq, true, "Use sequence mode");
-
 namespace paddle {
 namespace operators {
 
@@ -87,7 +85,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Hidden");
 
   int xx_width;
-  if (FLAGS_gru_use_seq) {
+  if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
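
Note on the width rule above: in seq mode XX always holds the FC projection of X (width 3D, i.e. wx_dims[1]), while in batch mode XX may instead hold the batch-reordered copy of X, so it takes the narrower of M and 3D. A minimal standalone sketch (helper name hypothetical, not part of the op):

// Hypothetical helper mirroring the InferShape logic above.
// Seq mode: XX stores the FC result, always D3 = 3 * D wide.
// Batch mode: XX stores whichever tensor is reordered first,
// so its width is min(M, D3).
int XXWidth(bool use_seq, int M, int D3) {
  if (use_seq) return D3;
  return M > D3 ? D3 : M;
}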
@@ -136,7 +134,10 @@ void FusionGRUOpMaker::Make() {
             "where T is the total time steps in this mini-batch, "
             "D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
-  AddOutput("BatchedInput", "(LoDTensor) (T x 3D)").AsIntermediate();
+  AddOutput("BatchedInput",
+            "(LoDTensor) This is the batched result of input X "
+            "or the batched result after fc, shape (T x 3D).")
+      .AsIntermediate();
   AddOutput("BatchedOut", "(LoDTensor) (T x D) save batched hidden.")
       .AsIntermediate();
   AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
@@ -153,6 +154,10 @@ void FusionGRUOpMaker::Make() {
                 "(bool, default: False) "
                 "whether to compute reversed GRU.")
       .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, default: True) "
+                "whether to use seq mode to compute GRU.")
+      .SetDefault(true);
   AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuses the fully-connected operator into GRU,
@@ -164,7 +169,7 @@ template <typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    if (FLAGS_gru_use_seq) {
+    if (ctx.Attr<bool>("use_seq")) {
       SeqCompute(ctx);
     } else {
       BatchCompute(ctx);
@@ -188,31 +193,35 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     cross = math::vec_cross<T, platform::jit::isa_any>;               \
   }
 
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* wx = ctx.Input<Tensor>("WeightX");            \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+#define INIT_BASE_SIZES                   \
+  auto x_dims = x->dims();   /* T x M */  \
+  auto wh_dims = wh->dims(); /* D x 3D */ \
+  const int total_T = x_dims[0];          \
+  const int M = x_dims[1];                \
+  const int D = wh_dims[0];               \
+  const int D3 = wh_dims[1];              \
+  const int D2 = D * 2;
+
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-
-    auto* xx = ctx.Output<LoDTensor>("XX");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
     INIT_VEC_FUNC
 
     auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 3D
     const int N = x_lod[0].size() - 1;
-    const int total_T = x_dims[0];
-    const int M = x_dims[1];
-    const int D3 = wh_dims[1];
-    const int D = wh_dims[0];
-    const int D2 = D * 2;
-
     const T* x_data = x->data<T>();
-    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wx_data = wx->data<T>();
     const T* wh_data = wh->data<T>();
     const T* wh_state_data = wh_data + D * D2;
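
For reference, the sizes pulled out by INIT_BASE_SIZES encode the GRU layout: X is T x M, WeightH is D x 3D with D3 = 3 * D, and the wh_state_data offset above assumes WeightH packs a D x 2D update/reset gate block followed by a D x D candidate-state block. A hedged sketch of that offset (hypothetical helper, not Paddle API):

// Illustration only: the candidate-state weights start D * D2
// elements into the WeightH buffer, past the D x 2D gate block.
template <typename T>
const T* StateWeightPtr(const T* wh_data, int D) {
  const int D2 = D * 2;
  return wh_data + D * D2;  // skip the update/reset gate block
}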
@@ -221,7 +230,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
-                                      xx_data, bias ? bias->data<T>() : NULL);
+                                      xx_data,
+                                      bias ? bias->data<T>() : nullptr);
 
     int xx_offset = D3;
     int gate_offset = D;
@@ -239,7 +249,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < N; ++i) {
       int bid = is_reverse ? N - 1 - i : i;
       int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
-      const T* prev_hidden_data = NULL;
+      const T* prev_hidden_data = nullptr;
       int tstart = 0;
       if (h0_data) {
         prev_hidden_data = h0_data + bid * D;
@@ -282,19 +292,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* h0 = ctx.Input<Tensor>("H0");
+    if (x->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
 
     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
-    auto* xx = ctx.Output<LoDTensor>("XX");
     auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
     auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-    INIT_VEC_FUNC
 
     const T* x_data = x->data<T>();
     const T* wx_data = wx->data<T>();
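
The early exit added at the top of BatchCompute is a degenerate-batch shortcut: a level-0 LoD with exactly two offsets, e.g. {0, T}, describes a single sequence, so the batch reordering would be pure overhead and SeqCompute handles it directly. A standalone sketch of the check (plain std::vector stands in for the LoD type):

#include <cstddef>
#include <vector>

// lod0 holds sequence boundary offsets into the rows of X:
// {0, 4, 9} is two sequences (lengths 4 and 5); {0, 7} is one.
bool IsSingleSequence(const std::vector<size_t>& lod0) {
  return lod0.size() == 2;
}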
@@ -304,25 +312,20 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
     hidden_out->mutable_data<T>(ctx.GetPlace());
 
-    auto x_dims = x->dims();
-    auto wx_dims = wx->dims();
-    const int D3 = wx_dims[1];
-    const int D = D3 / 3;
-    const int D2 = D * 2;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    if (x_dims[1] > wx_dims[1]) {
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        x_data, wx_data, xx_data,
-                                        bias ? bias->data<T>() : NULL);
+    if (M > D3) {
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                        xx_data,
+                                        bias ? bias->data<T>() : nullptr);
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
     } else {
       to_batch(dev_ctx, *x, xx, true, is_reverse);
       batched_input->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        xx_data, wx_data, batched_input_data,
-                                        bias ? bias->data<T>() : NULL);
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
+                                        batched_input_data,
+                                        bias ? bias->data<T>() : nullptr);
     }
 
     auto batched_lod = batched_input->lod();
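
The rewritten branch makes the heuristic readable: both orders compute the same FC (a total_T x M input times an M x 3D weight), but the seq-to-batch copy touches every row of whichever tensor it reorders. If the input is wider than the projection (M > D3), projecting first shrinks the data before the copy; otherwise copying first moves less. A rough sketch of that reasoning (hypothetical helper, not the op's code):

// The reorder copy moves total_T * width elements, so pick the
// order that reorders the narrower tensor.
bool ProjectBeforeReorder(int M, int D3) {
  return M > D3;  // FC output (D3 wide) is narrower than the input
}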
@@ -331,7 +334,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     reordered_h0->Resize({max_bs, D});
 
     int tstart = 0;
-    T* prev_hidden_data = NULL;
+    T* prev_hidden_data = nullptr;
     if (h0) {
       // reorder h0
       T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
@@ -415,6 +418,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     to_seq(dev_ctx, *batched_out, hidden_out);
   }
 #undef INIT_VEC_FUNC
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
 };
 
 }  // namespace operators