@@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
85
85
T *prev_output_value, int frame_size,
86
86
ActivationType active_gate) {
87
87
#ifdef __AVX__
88
- __m256 r_value_update_gate;
89
- __m256 r_value_reset_gate;
88
+ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps ( 0 . 0f ) ;
89
+ __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps ( 0 . 0f ) ;
90
90
__m256 r_value_reset_output;
91
- __m256 r_prev_out = _mm256_set1_ps (0 .0f );
92
- __m256 *update_gate = reinterpret_cast <__m256 *>(gate_value);
93
- __m256 *reset_gate = reinterpret_cast <__m256 *>(gate_value + frame_size);
91
+ __m256 r_prev_out = _mm256_set1_ps (0 .0f ),
92
+ r_prev_out_last = _mm256_set1_ps (0 .0f );
93
+ T *update_gate = gate_value;
94
+ T *reset_gate = gate_value + frame_size;
95
+ int block = 8 ;
96
+ const int n = frame_size;
97
+ const int rest = n % block;
98
+ const int end = n - rest;
99
+ int i = 0 ;
100
+
101
+ if (rest > 0 ) {
102
+ i = n - block;
103
+ r_value_update_gate_last =
104
+ _mm256_loadu_ps ((const float *)(update_gate + i));
105
+ r_value_reset_gate_last = _mm256_loadu_ps ((const float *)(reset_gate + i));
106
+ if (prev_output_value) {
107
+ r_prev_out_last = _mm256_loadu_ps ((const float *)(prev_output_value + i));
108
+ }
109
+ }
94
110
95
- for (int i = 0 ; i < frame_size / 8 ; i++ ) {
96
- r_value_update_gate = update_gate[i] ;
97
- r_value_reset_gate = reset_gate[i] ;
111
+ for (i = 0 ; i < end ; i += block ) {
112
+ r_value_update_gate = _mm256_loadu_ps (( const float *)( update_gate + i)) ;
113
+ r_value_reset_gate = _mm256_loadu_ps (( const float *)( reset_gate + i)) ;
98
114
if (prev_output_value) {
99
- r_prev_out = ( reinterpret_cast <__m256 *> (prev_output_value))[i] ;
115
+ r_prev_out = _mm256_loadu_ps (( const float *) (prev_output_value + i)) ;
100
116
}
101
117
102
118
op_reset_output (&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
103
119
&r_value_reset_output, active_gate);
104
120
105
- update_gate[i] = r_value_update_gate;
106
- reset_gate[i] = r_value_reset_gate;
107
- (reinterpret_cast <__m256 *>(reset_output_value))[i] = r_value_reset_output;
121
+ _mm256_storeu_ps (reinterpret_cast <float *>(update_gate + i),
122
+ r_value_update_gate);
123
+ _mm256_storeu_ps (reinterpret_cast <float *>(reset_gate + i),
124
+ r_value_reset_gate);
125
+ _mm256_storeu_ps (reinterpret_cast <float *>(reset_output_value + i),
126
+ r_value_reset_output);
127
+ }
128
+
129
+ if (rest > 0 ) {
130
+ i = n - block;
131
+
132
+ op_reset_output (&r_value_update_gate_last, &r_value_reset_gate_last,
133
+ &r_prev_out_last, &r_value_reset_output, active_gate);
134
+
135
+ _mm256_storeu_ps (reinterpret_cast <float *>(update_gate + i),
136
+ r_value_update_gate_last);
137
+ _mm256_storeu_ps (reinterpret_cast <float *>(reset_gate + i),
138
+ r_value_reset_gate_last);
139
+ _mm256_storeu_ps (reinterpret_cast <float *>(reset_output_value + i),
140
+ r_value_reset_output);
108
141
}
109
142
#endif
110
143
}
@@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
115
148
T *output_value, int frame_size,
116
149
ActivationType active_node) {
117
150
#ifdef __AVX__
118
- __m256 r_value_update_gate;
119
- __m256 r_value_frame_state;
120
- __m256 r_prev_out = _mm256_set1_ps (0 .0f );
151
+ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps (0 .0f );
152
+ __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps (0 .0f );
153
+ __m256 r_prev_out = _mm256_set1_ps (0 .0f ),
154
+ r_prev_out_last = _mm256_set1_ps (0 .0f );
121
155
__m256 r_output;
122
- __m256 *update_gate = reinterpret_cast <__m256 *>(gate_value);
123
- __m256 *frame_state = reinterpret_cast <__m256 *>(gate_value + frame_size * 2 );
156
+ T *update_gate = gate_value;
157
+ T *frame_state = gate_value + frame_size * 2 ;
158
+ int block = 8 ;
159
+ const int n = frame_size;
160
+ const int rest = n % block;
161
+ const int end = n - rest;
162
+ int i = 0 ;
163
+
164
+ if (rest > 0 ) {
165
+ i = n - block;
166
+ r_value_update_gate_last =
167
+ _mm256_loadu_ps ((const float *)(update_gate + i));
168
+ r_value_frame_state_last =
169
+ _mm256_loadu_ps ((const float *)(frame_state + i));
170
+ if (prev_output_value) {
171
+ r_prev_out_last = _mm256_loadu_ps ((const float *)(prev_output_value + i));
172
+ }
173
+ }
124
174
125
- for (int i = 0 ; i < frame_size / 8 ; i++ ) {
126
- r_value_update_gate = update_gate[i] ;
127
- r_value_frame_state = frame_state[i] ;
175
+ for (i = 0 ; i < end ; i += block ) {
176
+ r_value_update_gate = _mm256_loadu_ps (( const float *)( update_gate + i)) ;
177
+ r_value_frame_state = _mm256_loadu_ps (( const float *)( frame_state + i)) ;
128
178
if (prev_output_value) {
129
- r_prev_out = ( reinterpret_cast <__m256 *> (prev_output_value))[i] ;
179
+ r_prev_out = _mm256_loadu_ps (( const float *) (prev_output_value + i)) ;
130
180
}
131
181
132
182
op_final_output (&r_value_update_gate, &r_value_frame_state, &r_prev_out,
133
183
&r_output, active_node);
134
184
135
- frame_state[i] = r_value_frame_state;
136
- (reinterpret_cast <__m256 *>(output_value))[i] = r_output;
185
+ _mm256_storeu_ps (reinterpret_cast <float *>(frame_state + i),
186
+ r_value_frame_state);
187
+ _mm256_storeu_ps (reinterpret_cast <float *>(output_value + i), r_output);
188
+ }
189
+
190
+ if (rest > 0 ) {
191
+ i = n - block;
192
+ op_final_output (&r_value_update_gate_last, &r_value_frame_state_last,
193
+ &r_prev_out_last, &r_output, active_node);
194
+
195
+ _mm256_storeu_ps (reinterpret_cast <float *>(frame_state + i),
196
+ r_value_frame_state_last);
197
+ _mm256_storeu_ps (reinterpret_cast <float *>(output_value + i), r_output);
137
198
}
199
+
138
200
#endif
139
201
}
140
202
@@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
143
205
GRUMetaValue<T> value, int frame_size,
144
206
int batch_size, ActivationType active_gate) {
145
207
for (int b = 0 ; b < batch_size; b++) {
146
- if (OpResetOutput::avx && !(frame_size & (8 - 1 )) && (sizeof (T) == 4 )) {
208
+ if (OpResetOutput::avx && (frame_size > static_cast <int >(8 - 1 )) &&
209
+ (sizeof (T) == 4 )) {
147
210
hl_avx_gru_forward_reset_output (
148
211
op_reset_output, value.gate_value , value.reset_output_value ,
149
212
value.prev_out_value , frame_size, active_gate);
@@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output,
166
229
GRUMetaValue<T> value, int frame_size,
167
230
int batch_size, ActivationType active_node) {
168
231
for (int b = 0 ; b < batch_size; b++) {
169
- if (OpFinalOutput::avx && !(frame_size & (8 - 1 )) && (sizeof (T) == 4 )) {
232
+ if (OpFinalOutput::avx && (frame_size > static_cast <int >(8 - 1 )) &&
233
+ (sizeof (T) == 4 )) {
170
234
hl_avx_gru_forward_final_output (op_final_output, value.gate_value ,
171
235
value.prev_out_value , value.output_value ,
172
236
frame_size, active_node);
0 commit comments