@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/math/jit_code.h"
16
- #include " paddle/fluid/operators/math/jit_kernel.h"
17
- #include " paddle/fluid/platform/cpu_info.h"
16
+ #include " paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
18
17
19
18
namespace paddle {
20
19
namespace operators {
@@ -111,60 +110,26 @@ void VXXJitCode::generate() {
111
110
ret ();
112
111
}
113
112
114
- #define ALIGN32 __attribute__ ((aligned(32 )))
115
- #define EXP_HIG 88 .3762626647949f
116
- #define EXP_LOW -88 .3762626647949f
117
- #define CEPHES_LOG2EF 1.44269504088896341
118
- #define CEPHES_EXP_C1 0.693359375
119
- #define CEPHES_EXP_C2 -2.12194440e-4
120
- #define CEPHES_EXP_P0 1.9875691500E-4
121
- #define CEPHES_EXP_P1 1.3981999507E-3
122
- #define CEPHES_EXP_P2 8.3334519073E-3
123
- #define CEPHES_EXP_P3 4.1665795894E-2
124
- #define CEPHES_EXP_P4 1.6666665459E-1
125
- #define CEPHES_EXP_P5 5.0000001201E-1
113
+ const float exp_float_consts[] ALIGN32 = {REPEAT_8TIMES (1 .f ),
114
+ REPEAT_8TIMES (2 .f ),
115
+ REPEAT_8TIMES (0 .5f ),
116
+ REPEAT_8TIMES (EXP_HIG),
117
+ REPEAT_8TIMES (EXP_LOW),
118
+ REPEAT_8TIMES (CEPHES_LOG2EF),
119
+ REPEAT_8TIMES (CEPHES_EXP_C1),
120
+ REPEAT_8TIMES (CEPHES_EXP_C2),
121
+ REPEAT_8TIMES (CEPHES_EXP_P0),
122
+ REPEAT_8TIMES (CEPHES_EXP_P1),
123
+ REPEAT_8TIMES (CEPHES_EXP_P2),
124
+ REPEAT_8TIMES (CEPHES_EXP_P3),
125
+ REPEAT_8TIMES (CEPHES_EXP_P4),
126
+ REPEAT_8TIMES (CEPHES_EXP_P5),
127
+ REPEAT_8TIMES (EXP_MAX_INPUT),
128
+ REPEAT_8TIMES (SIGMOID_THRESHOLD_MAX),
129
+ REPEAT_8TIMES (SIGMOID_THRESHOLD_MIN)};
126
130
127
- #define REPEAT_8TIMES (val ) val, val, val, val, val, val, val, val
128
-
129
- #define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof (float )
130
- #define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof (float )
131
- #define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof (float )
132
- #define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof (float )
133
- #define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof (float )
134
- #define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof (float )
135
- #define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof (float )
136
- #define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof (float )
137
- #define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof (float )
138
- #define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof (float )
139
- #define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof (float )
140
- #define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof (float )
141
- #define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof (float )
142
- #define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof (float )
143
- #define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof (float )
144
- #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof (float )
145
- #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof (float )
146
-
147
- static const float exp_float_consts[] ALIGN32 = {
148
- REPEAT_8TIMES (1 .f ),
149
- REPEAT_8TIMES (2 .f ),
150
- REPEAT_8TIMES (0 .5f ),
151
- REPEAT_8TIMES (EXP_HIG),
152
- REPEAT_8TIMES (EXP_LOW),
153
- REPEAT_8TIMES (CEPHES_LOG2EF),
154
- REPEAT_8TIMES (CEPHES_EXP_C1),
155
- REPEAT_8TIMES (CEPHES_EXP_C2),
156
- REPEAT_8TIMES (CEPHES_EXP_P0),
157
- REPEAT_8TIMES (CEPHES_EXP_P1),
158
- REPEAT_8TIMES (CEPHES_EXP_P2),
159
- REPEAT_8TIMES (CEPHES_EXP_P3),
160
- REPEAT_8TIMES (CEPHES_EXP_P4),
161
- REPEAT_8TIMES (CEPHES_EXP_P5),
162
- REPEAT_8TIMES (EXP_MAX_INPUT),
163
- REPEAT_8TIMES (SIGMOID_THRESHOLD_MAX),
164
- REPEAT_8TIMES (SIGMOID_THRESHOLD_MIN)};
165
-
166
- static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES (0x7f )};
167
- static int g_tmp_mem[16 ] ALIGN32 = {0 };
131
+ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES (0x7f )};
132
+ int g_tmp_mem[16 ] ALIGN32 = {0 };
168
133
169
134
bool VActJitCode::init (int d, operand_type type) {
170
135
bool ok = MayIUse (avx);
@@ -177,146 +142,6 @@ bool VActJitCode::init(int d, operand_type type) {
177
142
}
178
143
}
179
144
180
- void VActJitCode::exp_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
181
- int fy_idx, int mask_idx, int tmp_idx) {
182
- assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enfore
183
- // check all idx can not equal
184
- ymm_t ymm_fx = ymm_t (fx_idx);
185
- ymm_t ymm_fy = ymm_t (fy_idx);
186
- ymm_t ymm_mask = ymm_t (mask_idx);
187
- ymm_t ymm_tmp = ymm_t (tmp_idx);
188
- reg64_t reg_ptr_global = rax;
189
- push (reg_ptr_global);
190
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
191
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
192
- vminps (ymm_src, ymm_src, ymm_tmp);
193
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
194
- vmaxps (ymm_src, ymm_src, ymm_tmp);
195
- // express exp(x) as exp(g + n*log(2))
196
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
197
- vmulps (ymm_fx, ymm_src, ymm_tmp);
198
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
199
- vaddps (ymm_fx, ymm_fx, ymm_tmp);
200
- vroundps (ymm_fy, ymm_fx, 0x01 );
201
- // if greater, substract 1
202
- vcmpgtps (ymm_mask, ymm_fy, ymm_fx);
203
- vmovaps (ymm_tmp, ptr[reg_ptr_global]);
204
- vandps (ymm_mask, ymm_mask, ymm_tmp);
205
- vsubps (ymm_fx, ymm_fy, ymm_mask);
206
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
207
- vmulps (ymm_fy, ymm_fx, ymm_tmp);
208
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
209
- ymm_t ymm_z = ymm_t (ymm_mask.getIdx ());
210
- vmulps (ymm_z, ymm_fx, ymm_tmp);
211
- vsubps (ymm_src, ymm_src, ymm_fy);
212
- vsubps (ymm_src, ymm_src, ymm_z);
213
- vmulps (ymm_z, ymm_src, ymm_src);
214
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
215
- vmulps (ymm_dst, ymm_src, ymm_tmp);
216
- for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
217
- i += (YMM_FLOAT_BLOCK * sizeof (float ))) {
218
- vmovaps (ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
219
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
220
- vmulps (ymm_dst, ymm_dst, ymm_src);
221
- }
222
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
223
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
224
- vmulps (ymm_dst, ymm_dst, ymm_z);
225
- vaddps (ymm_dst, ymm_dst, ymm_src);
226
- vmovaps (ymm_tmp, ptr[reg_ptr_global]);
227
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
228
- // build 2^n
229
- ymm_t ymm_int = ymm_fx;
230
- vcvttps2dq (ymm_int, ymm_fx);
231
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_int_0x7f));
232
- vmovdqa (ymm_tmp, ptr[reg_ptr_global]);
233
- if (MayIUse (avx2)) {
234
- vpaddd (ymm_int, ymm_int, ymm_tmp);
235
- vpslld (ymm_int, ymm_int, 23 );
236
- } else if (MayIUse (avx)) {
237
- xmm_t xtmp1 = xmm_t (ymm_int.getIdx ());
238
- xmm_t xtmp2 = xmm_t (ymm_tmp.getIdx ());
239
- reg64_t reg_ptr_tmp = reg_ptr_global;
240
- mov (reg_ptr_tmp, reinterpret_cast <size_t >(g_tmp_mem));
241
- vmovdqa (ptr[reg_ptr_tmp], ymm_int);
242
- vmovdqa (ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof (float )], ymm_tmp);
243
- vpaddd (xtmp1, xtmp1, xtmp2);
244
- vpslld (xtmp1, xtmp1, 23 );
245
- vmovdqa (ptr[reg_ptr_tmp], xtmp1);
246
- // next 128bits
247
- vmovdqa (xtmp1, ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )]);
248
- vmovdqa (xtmp2,
249
- ptr[reg_ptr_tmp +
250
- (YMM_FLOAT_BLOCK + 4 /* xmm float block*/ ) * sizeof (float )]);
251
- vpaddd (xtmp1, xtmp1, xtmp2);
252
- vpslld (xtmp1, xtmp1, 23 );
253
- vmovdqa (ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )], xtmp1);
254
- // load out
255
- vmovdqa (ymm_int, ptr[reg_ptr_tmp]);
256
- }
257
- vmulps (ymm_dst, ymm_dst, ymm_int);
258
- pop (reg_ptr_global);
259
- }
260
-
261
- void VActJitCode::exp_xmm (xmm_t & ymm_dst, xmm_t & ymm_src, int fx_idx,
262
- int fy_idx, int mask_idx, int tmp_idx) {
263
- assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enfore
264
- // check all idx can not equal
265
- xmm_t ymm_fx = xmm_t (fx_idx);
266
- xmm_t ymm_fy = xmm_t (fy_idx);
267
- xmm_t ymm_mask = xmm_t (mask_idx);
268
- xmm_t ymm_tmp = xmm_t (tmp_idx);
269
- reg64_t reg_ptr_global = rax;
270
- push (reg_ptr_global);
271
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
272
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
273
- vminps (ymm_src, ymm_src, ymm_tmp);
274
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
275
- vmaxps (ymm_src, ymm_src, ymm_tmp);
276
- // express exp(x) as exp(g + n*log(2))
277
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
278
- vmulps (ymm_fx, ymm_src, ymm_tmp);
279
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
280
- vaddps (ymm_fx, ymm_fx, ymm_tmp);
281
- vroundps (ymm_fy, ymm_fx, 0x01 );
282
- // if greater, substract 1
283
- vcmpgtps (ymm_mask, ymm_fy, ymm_fx);
284
- vmovaps (ymm_tmp, ptr[reg_ptr_global]);
285
- vandps (ymm_mask, ymm_mask, ymm_tmp);
286
- vsubps (ymm_fx, ymm_fy, ymm_mask);
287
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
288
- vmulps (ymm_fy, ymm_fx, ymm_tmp);
289
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
290
- xmm_t ymm_z = xmm_t (ymm_mask.getIdx ());
291
- vmulps (ymm_z, ymm_fx, ymm_tmp);
292
- vsubps (ymm_src, ymm_src, ymm_fy);
293
- vsubps (ymm_src, ymm_src, ymm_z);
294
- vmulps (ymm_z, ymm_src, ymm_src);
295
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
296
- vmulps (ymm_dst, ymm_src, ymm_tmp);
297
- for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
298
- i += (YMM_FLOAT_BLOCK * sizeof (float ))) {
299
- vmovaps (ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
300
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
301
- vmulps (ymm_dst, ymm_dst, ymm_src);
302
- }
303
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
304
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
305
- vmulps (ymm_dst, ymm_dst, ymm_z);
306
- vaddps (ymm_dst, ymm_dst, ymm_src);
307
- vmovaps (ymm_tmp, ptr[reg_ptr_global]);
308
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
309
- // build 2^n
310
- xmm_t ymm_int = ymm_fx;
311
- vcvttps2dq (ymm_int, ymm_fx);
312
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_int_0x7f));
313
- vmovdqa (ymm_tmp, ptr[reg_ptr_global]);
314
- vpaddd (ymm_int, ymm_int, ymm_tmp);
315
- vpslld (ymm_int, ymm_int, 23 );
316
- vmulps (ymm_dst, ymm_dst, ymm_int);
317
- pop (reg_ptr_global);
318
- }
319
-
320
145
void VActJitCode::sigmoid_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
321
146
int fy_idx, int mask_idx, int tmp_idx) {
322
147
// y = 1 / (1 + e^-x)
@@ -330,7 +155,7 @@ void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
330
155
vmaxps (ymm_src, ymm_src, ymm_tmp);
331
156
vxorps (ymm_tmp, ymm_tmp, ymm_tmp);
332
157
vsubps (ymm_src, ymm_tmp, ymm_src);
333
- exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
158
+ exp_jmm< ymm_t > (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
334
159
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
335
160
vaddps (ymm_dst, ymm_dst, ymm_tmp);
336
161
vdivps (ymm_dst, ymm_tmp, ymm_dst);
@@ -349,7 +174,7 @@ void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
349
174
vxorps (ymm_zero, ymm_zero, ymm_zero);
350
175
vsubps (ymm_tmp, ymm_zero, ymm_tmp);
351
176
vmulps (ymm_src, ymm_src, ymm_tmp);
352
- exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
177
+ exp_jmm< ymm_t > (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
353
178
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
354
179
vaddps (ymm_dst, ymm_dst, ymm_tmp);
355
180
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
@@ -373,7 +198,7 @@ void VActJitCode::generate() {
373
198
relu_jmm<ymm_t >(ymm_dst, ymm_src, ymm_zero);
374
199
break ;
375
200
case operand_type::exp:
376
- exp_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
201
+ exp_jmm< ymm_t > (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
377
202
break ;
378
203
case operand_type::sigmoid:
379
204
sigmoid_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
@@ -409,7 +234,7 @@ void VActJitCode::generate() {
409
234
relu_jmm<xmm_t >(xmm_dst, xmm_src, xmm_zero);
410
235
break ;
411
236
case operand_type::exp:
412
- exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
237
+ exp_jmm< xmm_t > (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
413
238
break ;
414
239
default :
415
240
break ;
0 commit comments