@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
197
197
static int g_tmp_mem[16 ] ALIGN32 = {0 };
198
198
199
199
void VExpJitCode::generate () {
200
- preCode ();
201
- // push some?
202
200
// in: ymm0, out: ymm1
203
- // use ymm 0~5 (and ymm 14~15 if avx only)
201
+ // use ymm 0~5, rax
204
202
int offset = 0 ;
205
203
vmovups (ymm_src, ptr[param1 + offset]);
206
204
mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
222
220
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
223
221
vmulps (ymm_fy, ymm_fx, ymm_tmp);
224
222
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
225
- vmulps (ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
223
+ ymm_t ymm_z = ymm_t (ymm_mask.getIdx ());
224
+ vmulps (ymm_z, ymm_fx, ymm_tmp);
226
225
vsubps (ymm_src, ymm_src, ymm_fy);
227
226
vsubps (ymm_src, ymm_src, ymm_z);
228
227
vmulps (ymm_z, ymm_src, ymm_src);
@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
240
239
vaddps (ymm_dst, ymm_dst, ymm_src);
241
240
vmovaps (ymm_tmp, ptr[reg_ptr_global]);
242
241
vaddps (ymm_dst, ymm_dst, ymm_tmp);
243
-
244
242
// build 2^n
245
243
ymm_t ymm_int = ymm_fx;
246
244
vcvttps2dq (ymm_int, ymm_fx);
@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
250
248
vpaddd (ymm_int, ymm_int, ymm_tmp);
251
249
vpslld (ymm_int, ymm_int, 23 );
252
250
} else if (MayIUse (avx)) {
253
- // use ymm_int, ymm_tmp and reg_ptr_global
254
- xmm_t xtmp1 = xmm_t (ymm_int); // or magic number should equal the ymm_int
255
- xmm_t xtmp2 = xmm_t (ymm_tmp); // or magic number should equal the ymm_tmp
256
- mov (reg_ptr_global , reinterpret_cast <size_t >(g_tmp_mem));
257
- vmovdqa (ptr[reg_ptr_global ], ymm_int);
258
- vmovdqa (ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof (float )], ymm_tmp);
251
+ xmm_t xtmp1 = xmm_t (ymm_int. getIdx ());
252
+ xmm_t xtmp2 = xmm_t (ymm_tmp. getIdx ());
253
+ reg64_t reg_ptr_tmp = reg_ptr_global;
254
+ mov (reg_ptr_tmp , reinterpret_cast <size_t >(g_tmp_mem));
255
+ vmovdqa (ptr[reg_ptr_tmp ], ymm_int);
256
+ vmovdqa (ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof (float )], ymm_tmp);
259
257
vpaddd (xtmp1, xtmp1, xtmp2);
260
258
vpslld (xtmp1, xtmp1, 23 );
261
- vmovdqa (ptr[reg_ptr_global ], xtmp1);
259
+ vmovdqa (ptr[reg_ptr_tmp ], xtmp1);
262
260
// next 128bits
263
- vmovdqa (xtmp1, ptr[reg_ptr_global + 4 /* xmm float block*/ * sizeof (float )]);
261
+ vmovdqa (xtmp1, ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )]);
264
262
vmovdqa (xtmp2,
265
- ptr[reg_ptr_global +
263
+ ptr[reg_ptr_tmp +
266
264
(AVX_FLOAT_BLOCK + 4 /* xmm float block*/ ) * sizeof (float )]);
267
265
vpaddd (xtmp1, xtmp1, xtmp2);
268
266
vpslld (xtmp1, xtmp1, 23 );
269
- vmovdqa (ptr[reg_ptr_global + 4 /* xmm float block*/ * sizeof (float )], xtmp1);
267
+ vmovdqa (ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )], xtmp1);
270
268
// load out
271
- vmovdqa (ymm_int, ptr[reg_ptr_global ]);
269
+ vmovdqa (ymm_int, ptr[reg_ptr_tmp ]);
272
270
}
273
271
vmulps (ymm_dst, ymm_dst, ymm_int);
274
272
vmovups (ptr[param2 + offset], ymm_dst);
275
273
276
- // ret();
277
- postCode ();
274
+ ret ();
278
275
}
279
276
280
277
} // namespace gen
0 commit comments