
Commit 5377edd

refine packed condition
1 parent 3bf3e77 commit 5377edd

File tree

2 files changed: 79 additions, 59 deletions

paddle/fluid/operators/gru_op.cc

Lines changed: 79 additions & 56 deletions
@@ -14,6 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+
+DECLARE_int32(paddle_num_threads);
 
 namespace paddle {
 namespace operators {
@@ -264,76 +269,94 @@ class GRUCPUKernel : public framework::OpKernel<T> {
       gru_value.prev_out_value = nullptr;
     }
     auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
+    size_t seq_len = batch_starts.size() - 1;
     auto active_node = math::detail::GetActivationType(
         context.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
 
 #ifdef PADDLE_WITH_MKLML
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-    // TODO(TJ): make a class
-    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                     frame_size * 2 /*width of weight*/,
-                                     frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_gate);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
-                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
-                   packed_gate);
-    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                      frame_size /*width of weight*/,
-                                      frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_state);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
-                   frame_size, T(1.0), gru_value.state_weight, frame_size,
-                   packed_state);
-#endif
-    for (size_t n = 0; n < num_batch; n++) {
-      int bstart = static_cast<int>(batch_starts[n]);
-      int bend = static_cast<int>(batch_starts[n + 1]);
-      int cur_batch_size = bend - bstart;
-
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.output_value = hidden_t.data<T>();
-      gru_value.gate_value = gate_t.data<T>();
-      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
 
-#ifdef PADDLE_WITH_MKLML
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
-                          frame_size * 2, frame_size, gru_value.prev_out_value,
-                          frame_size, packed_gate, frame_size * 2, T(1),
-                          gru_value.gate_value, frame_size * 3);
-      }
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
 
-      math::detail::forward_reset_output(
-          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_gate);
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
 
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(
-            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
-            gru_value.reset_output_value, frame_size, packed_state, frame_size,
-            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
       }
 
-      math::detail::forward_final_output(
-          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_node);
-#else
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
 #endif
-      gru_value.prev_out_value = gru_value.output_value;
-    }
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
 #ifdef PADDLE_WITH_MKLML
-    blas.GEMM_FREE(packed_gate);
-    blas.GEMM_FREE(packed_state);
+    }
 #endif
-
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
     to_seq(dev_ctx, *batch_hidden, hidden);
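
For readers skimming the diff: GEMM_ALLOC, GEMM_PACK, GEMM_COMPUTE, and GEMM_FREE are Paddle's BLAS wrappers over MKL's packed GEMM interface. The weight matrices are packed once before the sequence loop, and the packed buffers are then reused at every time step, so the packing cost is amortized over the whole sequence. The sketch below is not Paddle code; it is a minimal standalone illustration of the same pack-once / compute-many call sequence against the raw MKL cblas API, with arbitrary small sizes standing in for frame_size and the batch.

// Illustrative only: mirrors the alloc -> pack -> compute-in-loop -> free
// pattern used by the GRU kernel above, with made-up dimensions.
#include <mkl.h>
#include <vector>

int main() {
  const int m = 4;   // rows of the activation (a small batch)
  const int n = 64;  // columns of the weight (e.g. frame_size * 2)
  const int k = 32;  // inner dimension (e.g. frame_size)
  std::vector<float> a(m * k, 1.0f);  // changing input (prev hidden state)
  std::vector<float> w(k * n, 0.5f);  // fixed weight, worth packing once
  std::vector<float> c(m * n, 0.0f);  // output / accumulator (gate buffer)

  // Allocate MKL-managed storage and pack the weight matrix once.
  float* packed_w = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k,
                   1.0f, w.data(), n, packed_w);

  // Reuse the packed weight for many small GEMMs, one per time step.
  for (int step = 0; step < 10; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
                        a.data(), k, packed_w, n, 1.0f /*beta*/, c.data(), n);
  }

  cblas_sgemm_free(packed_w);
  return 0;
}

Packing has a real setup cost in time and memory, which is presumably why this commit only takes the packed path when FLAGS_paddle_num_threads >= 4 and otherwise falls back to the existing math::GRUUnitFunctor path; the threshold of four threads is the patch's heuristic, not an MKL requirement.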

paddle/fluid/operators/gru_op.h

Lines changed: 0 additions & 3 deletions
@@ -16,10 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
-#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
