Skip to content

Commit 83f4bc4

Browse files
committed
follow comment and refine code
1 parent 9838bac commit 83f4bc4

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16+
#include <string>
1617
#include "paddle/fluid/framework/lod_tensor.h"
1718

1819
namespace paddle {
@@ -97,6 +98,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
9798
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
9899
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
99100
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
101+
// TODO(TJ): get from attr
102+
op_desc.SetAttr("use_seq", true);
100103

101104
#define TMP_NAME(x) "at.new.tmp." #x
102105
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
@@ -134,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
134137

135138
auto fc_no_bias_handler = [&](
136139
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
137-
138140
#define GET_NODE(name__) \
139141
std::string name__##key = name_scope + "/" + #name__; \
140142
auto* name__##n = pattern->RetrieveNode(name__##key); \

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,10 @@ limitations under the License. */
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
1818
#include "paddle/fluid/operators/math/cpu_vec.h"
19-
#include "paddle/fluid/operators/math/detail/activation_functions.h"
2019
#include "paddle/fluid/operators/math/fc_compute.h"
21-
#include "paddle/fluid/operators/math/lstm_compute.h"
2220
#include "paddle/fluid/operators/math/sequence2batch.h"
2321
#include "paddle/fluid/platform/cpu_info.h"
2422

25-
DEFINE_bool(seq_mode, true, "Use sequence mode");
26-
2723
namespace paddle {
2824
namespace operators {
2925

@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
110106
ctx->ShareLoD("X", "Cell");
111107

112108
int xx_width;
113-
if (FLAGS_seq_mode) {
109+
if (ctx->Attrs().Get<bool>("use_seq")) {
114110
xx_width = wx_dims[1];
115111
} else {
116112
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
189185
"(bool, defalut: False) "
190186
"whether to compute reversed LSTM.")
191187
.SetDefault(false);
188+
AddAttr<bool>("use_seq",
189+
"(bool, defalut: True) "
190+
"whether to use seq mode to compute.")
191+
.SetDefault(true);
192192
AddAttr<std::string>("gate_activation",
193193
"(string, default: sigmoid)"
194194
"The activation for input gate, forget gate and output "
@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
264264
const int N = x_lod[0].size() - 1; // batch size
265265

266266
const T* x_data = x->data<T>();
267-
const T* h0_data = h0 ? h0->data<T>() : NULL;
268-
const T* c0_data = c0 ? c0->data<T>() : NULL;
267+
const T* h0_data = h0 ? h0->data<T>() : nullptr;
268+
const T* c0_data = c0 ? c0->data<T>() : nullptr;
269269
const T* wx_data = wx->data<T>();
270270
const T* wh_data = wh->data<T>();
271271
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
295295
for (int i = 0; i < N; ++i) {
296296
int bid = is_reverse ? N - 1 - i : i;
297297
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
298-
const T* prev_c_data = NULL;
299-
const T* prev_h_data = NULL;
298+
const T* prev_c_data = nullptr;
299+
const T* prev_h_data = nullptr;
300300
int tstart = 0;
301301
if (h0_data) {
302302
prev_h_data = h0_data + bid * D;
@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
351351
void BatchCompute(const framework::ExecutionContext& ctx) const {
352352
using DeviceContext = platform::CPUDeviceContext;
353353
INIT_BASE_INPUT_OUTPUT
354-
if (x->lod()[0].size() == 2) { // batch size == 1
354+
if (x->lod()[0].size() == 2) {
355355
SeqCompute(ctx);
356+
return;
356357
}
357358
INIT_BASE_SIZES
358359
INIT_VEC_FUNC
@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
396397
reordered_c0->Resize({max_bs, D});
397398

398399
int tstart = 0;
399-
T* prev_h_data = NULL;
400-
T* prev_c_data = NULL;
400+
T* prev_h_data = nullptr;
401+
T* prev_c_data = nullptr;
401402
if (h0) {
402403
// reorder h0, c0
403404
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
489490
}
490491

491492
void Compute(const framework::ExecutionContext& ctx) const override {
492-
if (FLAGS_seq_mode) {
493+
if (ctx.Attr<bool>("use_seq")) {
493494
SeqCompute(ctx);
494495
} else {
495496
BatchCompute(ctx);

0 commit comments

Comments
 (0)