Commit 8c23f7c

fix blas and use packed weight
1 parent d9cc6b1 commit 8c23f7c

File tree

2 files changed: +28 -8 lines changed

paddle/fluid/operators/gru_op.h

Lines changed: 27 additions & 7 deletions
@@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel<T> {
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    // TODO(TJ): make a class, make one pack
+    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                     frame_size * 2 /*width of weight*/,
+                                     frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_gate);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                   packed_gate);
+    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                      frame_size /*width of weight*/,
+                                      frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_state);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                   frame_size, T(1.0), gru_value.state_weight, frame_size,
+                   packed_state);
+
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -110,20 +127,21 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       if (gru_value.prev_out_value) {
-        blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1,
-                  gru_value.prev_out_value, frame_size, gru_value.gate_weight,
-                  frame_size * 2, 1, gru_value.gate_value, frame_size * 3);
+        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
+                          frame_size * 2, frame_size, gru_value.prev_out_value,
+                          frame_size, packed_gate, frame_size * 2, T(1),
+                          gru_value.gate_value, frame_size * 3);
       }

       math::detail::forward_reset_output(
           math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
           cur_batch_size, active_gate);

       if (gru_value.prev_out_value) {
-        blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1,
-                  gru_value.reset_output_value, frame_size,
-                  gru_value.state_weight, frame_size, 1,
-                  gru_value.gate_value + frame_size * 2, frame_size * 3);
+        blas.GEMM_COMPUTE(
+            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+            gru_value.reset_output_value, frame_size, packed_state, frame_size,
+            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
       }

       math::detail::forward_final_output(
@@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel<T> {

       gru_value.prev_out_value = gru_value.output_value;
     }
+    blas.GEMM_FREE(packed_gate);
+    blas.GEMM_FREE(packed_state);

     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
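
The new calls follow the packed-GEMM pattern: GEMM_ALLOC and GEMM_PACK reorganize each constant weight matrix into the BLAS library's internal layout once, before the batch loop; GEMM_COMPUTE then multiplies against the packed buffer (CblasPacked) on every step; GEMM_FREE releases it after the loop. The leading dimension frame_size * 3 reflects that gate_value stores the three GRU gates side by side, so the first product fills the update/reset columns and the second fills the candidate block at offset frame_size * 2. Below is a minimal standalone sketch of the same pack-once, compute-many idea against the raw MKL cblas_sgemm_* extensions that this wrapper presumably forwards to; the sizes, loop count, and variable names are illustrative assumptions, not values from the commit.

// Sketch: pack a constant weight once, reuse it across time steps.
// Assumes an MKL build; m/n/k and the step count are made up.
#include <mkl.h>
#include <vector>

int main() {
  const MKL_INT m = 4, n = 64, k = 32;  // C: m x n, A: m x k, B (weight): k x n
  std::vector<float> a(m * k, 1.0f);    // activations, change every step
  std::vector<float> b(k * n, 0.5f);    // weight, constant across steps
  std::vector<float> c(m * n, 0.0f);

  // Pack the constant B operand once, outside the loop.
  float* packed_b = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k,
                   1.0f /*alpha*/, b.data(), n, packed_b);

  // Every step multiplies fresh activations against the packed weight;
  // beta = 1 accumulates into C, as the GRU kernel does into gate_value.
  for (int step = 0; step < 10; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
                        a.data(), k, packed_b, n, 1.0f /*beta*/, c.data(), n);
  }

  cblas_sgemm_free(packed_b);
  return 0;
}

The payoff is that the weight's layout transformation, normally redone inside every GEMM call, is paid once per sequence instead of once per time step, which matters for the small per-step matrices typical of RNNs.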

paddle/fluid/operators/math/blas.h

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ class BlasT : private Blas<DeviceContext> {

   template <typename... ARGS>
   T* GEMM_ALLOC(ARGS... args) const {
-    Base()->template GEMM_ALLOC<T>(args...);
+    return Base()->template GEMM_ALLOC<T>(args...);
   }

   template <typename... ARGS>
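
This one-line change fixes a genuine bug, not style: the wrapper is declared to return T* but discarded the result of the forwarded call, so control fell off the end of a non-void function. That is undefined behavior, and in practice the GRU kernel above would have received an indeterminate pointer from GEMM_ALLOC. A stripped-down illustration of the failure mode, with hypothetical names:

// Hypothetical reduction of the bug: a forwarding wrapper that forgets
// to return. This typically compiles with only a -Wreturn-type warning,
// but using the result is undefined behavior.
#include <cstdlib>

template <typename T>
T* RawAlloc(int n) { return static_cast<T*>(std::malloc(n * sizeof(T))); }

template <typename T>
T* BuggyWrapper(int n) {
  RawAlloc<T>(n);  // result discarded; function falls off the end
}                  // caller reads an indeterminate T* (UB)

template <typename T>
T* FixedWrapper(int n) {
  return RawAlloc<T>(n);  // the same one-line fix made in blas.h
}

Note that PADDLE_ENFORCE(packed_gate) in gru_op.h could not reliably catch the original bug, since the indeterminate value is often non-null.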
