@@ -16,14 +16,10 @@ limitations under the License. */
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/blas.h"
18
18
#include " paddle/fluid/operators/math/cpu_vec.h"
19
- #include " paddle/fluid/operators/math/detail/activation_functions.h"
20
19
#include " paddle/fluid/operators/math/fc_compute.h"
21
- #include " paddle/fluid/operators/math/lstm_compute.h"
22
20
#include " paddle/fluid/operators/math/sequence2batch.h"
23
21
#include " paddle/fluid/platform/cpu_info.h"
24
22
25
- DEFINE_bool (seq_mode, true , " Use sequence mode" );
26
-
27
23
namespace paddle {
28
24
namespace operators {
29
25
@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
110
106
ctx->ShareLoD (" X" , " Cell" );
111
107
112
108
int xx_width;
113
- if (FLAGS_seq_mode ) {
109
+ if (ctx-> Attrs (). Get < bool >( " use_seq " ) ) {
114
110
xx_width = wx_dims[1 ];
115
111
} else {
116
112
xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
189
185
" (bool, defalut: False) "
190
186
" whether to compute reversed LSTM." )
191
187
.SetDefault (false );
188
+ AddAttr<bool >(" use_seq" ,
189
+ " (bool, defalut: True) "
190
+ " whether to use seq mode to compute." )
191
+ .SetDefault (true );
192
192
AddAttr<std::string>(" gate_activation" ,
193
193
" (string, default: sigmoid)"
194
194
" The activation for input gate, forget gate and output "
@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
264
264
const int N = x_lod[0 ].size () - 1 ; // batch size
265
265
266
266
const T* x_data = x->data <T>();
267
- const T* h0_data = h0 ? h0->data <T>() : NULL ;
268
- const T* c0_data = c0 ? c0->data <T>() : NULL ;
267
+ const T* h0_data = h0 ? h0->data <T>() : nullptr ;
268
+ const T* c0_data = c0 ? c0->data <T>() : nullptr ;
269
269
const T* wx_data = wx->data <T>();
270
270
const T* wh_data = wh->data <T>();
271
271
T* xx_data = xx->mutable_data <T>(ctx.GetPlace ());
@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
295
295
for (int i = 0 ; i < N; ++i) {
296
296
int bid = is_reverse ? N - 1 - i : i;
297
297
int seq_len = x_lod[0 ][bid + 1 ] - x_lod[0 ][bid];
298
- const T* prev_c_data = NULL ;
299
- const T* prev_h_data = NULL ;
298
+ const T* prev_c_data = nullptr ;
299
+ const T* prev_h_data = nullptr ;
300
300
int tstart = 0 ;
301
301
if (h0_data) {
302
302
prev_h_data = h0_data + bid * D;
@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
351
351
void BatchCompute (const framework::ExecutionContext& ctx) const {
352
352
using DeviceContext = platform::CPUDeviceContext;
353
353
INIT_BASE_INPUT_OUTPUT
354
- if (x->lod ()[0 ].size () == 2 ) { // batch size == 1
354
+ if (x->lod ()[0 ].size () == 2 ) {
355
355
SeqCompute (ctx);
356
+ return ;
356
357
}
357
358
INIT_BASE_SIZES
358
359
INIT_VEC_FUNC
@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
396
397
reordered_c0->Resize ({max_bs, D});
397
398
398
399
int tstart = 0 ;
399
- T* prev_h_data = NULL ;
400
- T* prev_c_data = NULL ;
400
+ T* prev_h_data = nullptr ;
401
+ T* prev_c_data = nullptr ;
401
402
if (h0) {
402
403
// reorder h0, c0
403
404
T* reordered_h0_data = reordered_h0->mutable_data <T>(place);
@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
489
490
}
490
491
491
492
void Compute (const framework::ExecutionContext& ctx) const override {
492
- if (FLAGS_seq_mode ) {
493
+ if (ctx. Attr < bool >( " use_seq " ) ) {
493
494
SeqCompute (ctx);
494
495
} else {
495
496
BatchCompute (ctx);
0 commit comments