Commit c7adb99

follow comment and refine code

1 parent f38905a commit c7adb99


paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 51 additions & 46 deletions
@@ -21,8 +21,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
-DEFINE_bool(gru_use_seq, true, "Use sequence mode");
-
 namespace paddle {
 namespace operators {
 
@@ -87,7 +85,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Hidden");
 
   int xx_width;
-  if (FLAGS_gru_use_seq) {
+  if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
@@ -136,7 +134,10 @@ void FusionGRUOpMaker::Make() {
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
-  AddOutput("BatchedInput", "(LoDTensor) (T x 3D)").AsIntermediate();
+  AddOutput("BatchedInput",
+            "(LoDTensor) This is the batched result of input X "
+            "or the batched result after fc, shape (T x 3D)")
+      .AsIntermediate();
   AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.")
       .AsIntermediate();
   AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
@@ -153,6 +154,10 @@ void FusionGRUOpMaker::Make() {
                 "(bool, default: False) "
                 "whether to compute reversed GRU.")
       .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, default: True) "
+                "whether to use seq mode to compute GRU.")
+      .SetDefault(true);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuses the fully-connected operator into GRU,
@@ -164,7 +169,7 @@ template <typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    if (FLAGS_gru_use_seq) {
+    if (ctx.Attr<bool>("use_seq")) {
       SeqCompute(ctx);
     } else {
       BatchCompute(ctx);
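
Taken together, the use_seq changes in the maker, InferShape, and this kernel dispatch move the mode switch from a process-wide gflag onto the op itself, where it is declared with AddAttr<bool>(...).SetDefault(true) and read back via ctx->Attrs().Get<bool>(...) or ctx.Attr<bool>(...). Unlike a gflag, an attribute is stored per op instance, so two fusion_gru ops in one program can run in different modes. A toy model of that difference (illustrative only, not Paddle's API):

#include <iostream>
#include <map>
#include <string>

// Toy model, not Paddle code: each op carries its own attribute map plus a
// default, which is roughly what AddAttr<bool>("use_seq", ...).SetDefault(true)
// sets up in the real OpMaker.
struct ToyOp {
  std::map<std::string, bool> attrs;
  bool Attr(const std::string& name, bool default_val) const {
    auto it = attrs.find(name);
    return it == attrs.end() ? default_val : it->second;
  }
};

int main() {
  ToyOp gru_a;                        // keeps the default: seq mode
  ToyOp gru_b{{{"use_seq", false}}};  // this instance opts into batch mode
  std::cout << gru_a.Attr("use_seq", true) << "\n";  // prints 1
  std::cout << gru_b.Attr("use_seq", true) << "\n";  // prints 0
}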
@@ -188,31 +193,35 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     cross = math::vec_cross<T, platform::jit::isa_any>;             \
   }
 
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* wx = ctx.Input<Tensor>("WeightX");            \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+#define INIT_BASE_SIZES                  \
+  auto x_dims = x->dims();   /* T x M*/  \
+  auto wh_dims = wh->dims(); /* D x 3D*/ \
+  const int total_T = x_dims[0];         \
+  const int M = x_dims[1];               \
+  const int D = wh_dims[0];              \
+  const int D3 = wh_dims[1];             \
+  const int D2 = D * 2;
+
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-
-    auto* xx = ctx.Output<LoDTensor>("XX");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
     INIT_VEC_FUNC
 
     auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 3D
     const int N = x_lod[0].size() - 1;
-    const int total_T = x_dims[0];
-    const int M = x_dims[1];
-    const int D3 = wh_dims[1];
-    const int D = wh_dims[0];
-    const int D2 = D * 2;
-
     const T* x_data = x->data<T>();
-    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wx_data = wx->data<T>();
     const T* wh_data = wh->data<T>();
     const T* wh_state_data = wh_data + D * D2;
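
The two new macros hoist the identical prologues of SeqCompute and BatchCompute into one place. They are textual rather than scoped: INIT_BASE_INPUT_OUTPUT assumes a variable named ctx at the expansion site, and INIT_BASE_SIZES further assumes x and wh are already declared. After preprocessing, the top of SeqCompute reads the same as the hand-written block it replaces; a sketch of the expansion (not text from the patch):

  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    auto* x = ctx.Input<LoDTensor>("X");
    // INIT_BASE_INPUT_OUTPUT expands to:
    auto* h0 = ctx.Input<Tensor>("H0");
    auto* wx = ctx.Input<Tensor>("WeightX");
    auto* wh = ctx.Input<Tensor>("WeightH");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* xx = ctx.Output<LoDTensor>("XX");
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    bool is_reverse = ctx.Attr<bool>("is_reverse");
    // INIT_BASE_SIZES expands to:
    auto x_dims = x->dims();   /* T x M */
    auto wh_dims = wh->dims(); /* D x 3D */
    const int total_T = x_dims[0];
    const int M = x_dims[1];
    const int D = wh_dims[0];
    const int D3 = wh_dims[1];
    const int D2 = D * 2;
    // ...
  }

The matching #undef lines added at the end of the class keep these names from leaking past this kernel.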
@@ -221,7 +230,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
-                                      xx_data, bias ? bias->data<T>() : NULL);
+                                      xx_data,
+                                      bias ? bias->data<T>() : nullptr);
 
     int xx_offset = D3;
     int gate_offset = D;
@@ -239,7 +249,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < N; ++i) {
       int bid = is_reverse ? N - 1 - i : i;
       int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
-      const T* prev_hidden_data = NULL;
+      const T* prev_hidden_data = nullptr;
       int tstart = 0;
       if (h0_data) {
         prev_hidden_data = h0_data + bid * D;
@@ -282,19 +292,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* h0 = ctx.Input<Tensor>("H0");
+    if (x->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
 
     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
-    auto* xx = ctx.Output<LoDTensor>("XX");
     auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
     auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-    INIT_VEC_FUNC
 
     const T* x_data = x->data<T>();
     const T* wx_data = wx->data<T>();
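
The new early exit leans on the LoD convention already visible in SeqCompute (N = x_lod[0].size() - 1): level 0 holds N + 1 cumulative offsets for N sequences, so a size of 2 means the mini-batch contains exactly one sequence. With a single sequence there is nothing to interleave, and the batch reordering would only add copies, so falling back to SeqCompute is strictly cheaper. A minimal standalone illustration of the offset convention (my own example, not from the patch):

#include <cstddef>
#include <iostream>
#include <vector>

// LoD level 0 stores cumulative offsets: N sequences -> N + 1 entries.
using LoDLevel = std::vector<std::size_t>;

int main() {
  LoDLevel one_seq{0, 5};           // one sequence of length 5 -> size() == 2
  LoDLevel three_seqs{0, 2, 5, 9};  // lengths 2, 3, 4 -> size() == 4
  std::cout << one_seq.size() - 1 << " sequence(s)\n";     // 1
  std::cout << three_seqs.size() - 1 << " sequence(s)\n";  // 3
}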
@@ -304,25 +312,20 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
     hidden_out->mutable_data<T>(ctx.GetPlace());
 
-    auto x_dims = x->dims();
-    auto wx_dims = wx->dims();
-    const int D3 = wx_dims[1];
-    const int D = D3 / 3;
-    const int D2 = D * 2;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    if (x_dims[1] > wx_dims[1]) {
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        x_data, wx_data, xx_data,
-                                        bias ? bias->data<T>() : NULL);
+    if (M > D3) {
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                        xx_data,
+                                        bias ? bias->data<T>() : nullptr);
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
     } else {
       to_batch(dev_ctx, *x, xx, true, is_reverse);
       batched_input->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        xx_data, wx_data, batched_input_data,
-                                        bias ? bias->data<T>() : NULL);
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
+                                        batched_input_data,
+                                        bias ? bias->data<T>() : nullptr);
     }
 
     auto batched_lod = batched_input->lod();
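
Rewriting the FC calls in terms of total_T, D3, and M does more than shorten them; it makes the branch condition legible. FCCompute projects the T x M input to T x 3D gate pre-activations, and to_batch only permutes rows, so the cheaper order is to reorder whichever tensor is narrower: project first when M > D3, reorder first otherwise. A back-of-the-envelope check of that reading, with hypothetical sizes (not from the patch):

#include <cstdio>

// Elements moved by the row reordering (to_batch) under each ordering,
// for T time steps, input width M, and gate width D3 = 3 * D.
long long fc_first(long long T, long long D3) { return T * D3; }
long long batch_first(long long T, long long M) { return T * M; }

int main() {
  const long long T = 100, M = 512, D = 64, D3 = 3 * D;  // hypothetical sizes
  std::printf("project first: %lld elements reordered\n", fc_first(T, D3));
  std::printf("reorder first: %lld elements reordered\n", batch_first(T, M));
  // M (512) > D3 (192), so projecting first moves fewer elements: the same
  // comparison as `if (M > D3)` in BatchCompute.
}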
@@ -331,7 +334,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     reordered_h0->Resize({max_bs, D});
 
     int tstart = 0;
-    T* prev_hidden_data = NULL;
+    T* prev_hidden_data = nullptr;
     if (h0) {
       // reorder h0
       T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
@@ -415,6 +418,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     to_seq(dev_ctx, *batched_out, hidden_out);
   }
 #undef INIT_VEC_FUNC
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
 };
 
 }  // namespace operators
