
Commit 87086b1

yihuaxu authored and tensor-tang committed
Refine activation for GRU operator (#13275)
* Optimize GRU with AVX instructions
* Clean code
* Add the unit test and fix the align issue
* Remove the remnant part of the unit test
* Code clean
* Fix the parameters length issue for fusion_gru to pass CI
* Change the default type to float32
1 parent d402234 commit 87086b1

File tree

2 files changed: +109 -37 lines changed

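The heart of the change in gru_cpu_kernel.h is to stop indexing the gate buffers as aligned __m256 arrays and to use unaligned _mm256_loadu_ps / _mm256_storeu_ps over 8-float blocks instead, adding one extra block that overlaps the tail when frame_size is not a multiple of 8. Below is a minimal stand-alone sketch of that pattern, not the PaddlePaddle code itself: the function name scale_in_place and the element-wise operation are illustrative, and it assumes frame_size >= 8 and compilation with AVX enabled (e.g. -mavx).

#include <immintrin.h>
#include <cstdio>
#include <vector>

// Sketch: process frame_size floats in 8-wide AVX blocks with unaligned
// loads/stores, covering any remainder by re-running the last full block
// so that it overlaps the tail. Assumes frame_size >= 8.
void scale_in_place(float *data, int frame_size, float factor) {
  const int block = 8;
  const int rest = frame_size % block;
  const int end = frame_size - rest;
  const __m256 f = _mm256_set1_ps(factor);

  // Load the tail block up front, before the main loop modifies the
  // buffer in place.
  __m256 last = _mm256_set1_ps(0.0f);
  if (rest > 0) {
    last = _mm256_loadu_ps(data + frame_size - block);
  }

  // Main loop over the full 8-wide blocks.
  for (int i = 0; i < end; i += block) {
    __m256 v = _mm256_loadu_ps(data + i);
    _mm256_storeu_ps(data + i, _mm256_mul_ps(v, f));
  }

  // Overlapping tail: lanes shared with the main loop are rewritten with
  // identical values; lanes past `end` get written here for the first time.
  if (rest > 0) {
    _mm256_storeu_ps(data + frame_size - block, _mm256_mul_ps(last, f));
  }
}

int main() {
  std::vector<float> v(19, 1.0f);  // 19 is not a multiple of 8
  scale_in_place(v.data(), static_cast<int>(v.size()), 2.0f);
  std::printf("%f %f\n", v[0], v[18]);  // both print 2.000000
}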

paddle/fluid/operators/math/detail/gru_cpu_kernel.h

Lines changed: 89 additions & 25 deletions
@@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *prev_output_value, int frame_size,
                                      ActivationType active_gate) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_reset_gate;
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_reset_output;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
+  T *update_gate = gate_value;
+  T *reset_gate = gate_value + frame_size;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate_last = _mm256_loadu_ps((const float *)(reset_gate + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
 
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_reset_gate = reset_gate[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate = _mm256_loadu_ps((const float *)(reset_gate + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
 
     op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                     &r_value_reset_output, active_gate);
 
-    update_gate[i] = r_value_update_gate;
-    reset_gate[i] = r_value_reset_gate;
-    (reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
+  }
+
+  if (rest > 0) {
+    i = n - block;
+
+    op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last,
+                    &r_prev_out_last, &r_value_reset_output, active_gate);
+
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
   }
 #endif
 }
@@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *output_value, int frame_size,
                                      ActivationType active_node) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_frame_state;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
   __m256 r_output;
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
+  T *update_gate = gate_value;
+  T *frame_state = gate_value + frame_size * 2;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state_last =
+        _mm256_loadu_ps((const float *)(frame_state + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
 
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_frame_state = frame_state[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state = _mm256_loadu_ps((const float *)(frame_state + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
 
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                     &r_output, active_node);
 
-    frame_state[i] = r_value_frame_state;
-    (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
+  }
+
+  if (rest > 0) {
+    i = n - block;
+    op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
+                    &r_prev_out_last, &r_output, active_node);
+
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
   }
+
 #endif
 }
 
@@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
           op_reset_output, value.gate_value, value.reset_output_value,
           value.prev_out_value, frame_size, active_gate);
@@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
                                       frame_size, active_node);
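Two details of the kernel change above are worth noting. First, the AVX path can now be taken whenever frame_size > 7 rather than only when it is a multiple of 8, because the remainder is covered by re-running the last full 8-wide block at offset frame_size - 8. Second, that tail block is loaded before the main loop touches the buffers: since the gates are updated in place, the overlapping lanes are recomputed from the same original inputs and the second store writes back identical values. A small worked example of the index arithmetic follows, using the frame_size = 19 shape that the new TestGRUOp2 case exercises (the values here are for illustration only, not taken from the patch):

#include <cassert>

int main() {
  const int frame_size = 19;              // matches D = 19 in TestGRUOp2
  const int block = 8;
  const int rest = frame_size % block;    // 3 leftover elements
  const int end = frame_size - rest;      // 16: main loop covers [0, 16)
  const int tail = frame_size - block;    // 11: tail block covers [11, 19)
  assert(rest == 3 && end == 16 && tail == 11);
  // Lanes 11..15 are written twice with identical results;
  // lanes 16..18 are written only by the tail block.
  return 0;
}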

python/paddle/fluid/tests/unittests/test_gru_op.py

Lines changed: 20 additions & 12 deletions
@@ -30,7 +30,8 @@ def gru(
         bias,  # 1 x 3D
         is_reverse,
         act_state,
-        act_gate):
+        act_gate,
+        dtype='float32'):
     def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
@@ -71,10 +72,10 @@ def _step(x, h_p, w, b, act_state, act_gate):
     T = sum(lod[0])
     N = len(lod[0])
     D = weight.shape[0]
-    batch_gate = np.zeros((T, 3 * D), dtype='float64')
-    batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
-    batch_hidden = np.zeros((T, D), dtype='float64')
-    hidden = np.zeros((T, D), dtype='float64')
+    batch_gate = np.zeros((T, 3 * D), dtype=dtype)
+    batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype)
+    batch_hidden = np.zeros((T, D), dtype=dtype)
+    hidden = np.zeros((T, D), dtype=dtype)
 
     idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
     h_p = h0[sorted_seqs]
@@ -108,23 +109,24 @@ def setUp(self):
         self.with_bias = True
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
+        self.dtype = 'float64'
         self.set_confs()
 
         T = sum(self.lod[0])
         N = len(self.lod[0])
 
-        input = np.random.rand(T, 3 * self.D).astype('float64')
-        weight = np.random.rand(self.D, 3 * self.D).astype('float64')
+        input = np.random.rand(T, 3 * self.D).astype(self.dtype)
+        weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype)
         bias = np.random.rand(
-            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
-                (1, 3 * self.D), dtype='float64')
+            1, 3 * self.D).astype(self.dtype) if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype=self.dtype)
         h0 = np.random.rand(
-            N, self.D).astype('float64') if self.with_h0 else np.zeros(
-                (N, self.D), dtype='float64')
+            N, self.D).astype(self.dtype) if self.with_h0 else np.zeros(
+                (N, self.D), dtype=self.dtype)
 
         batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
             input, self.lod, h0, weight, bias, self.is_reverse,
-            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
         self.inputs = {'Input': (input, self.lod), 'Weight': weight}
 
         if self.with_bias:
@@ -153,6 +155,12 @@ def test_check_grad(self):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOp2(TestGRUOp):
+    def set_confs(self):
+        self.D = 19
+        self.dtype = 'float32'
+
+
 class TestGRUOpNoInitial(TestGRUOp):
     def set_confs(self):
         self.with_h0 = False
