@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
118
118
ret ();
119
119
}
120
120
121
- bool ReluJitCode::init (int d) { return MayIUse (avx); }
122
-
123
- void ReluJitCode::generate () {
124
- int offset = 0 ;
125
- vxorps (ymm_zero, ymm_zero, ymm_zero);
126
- for (int i = 0 ; i < num_ / AVX_FLOAT_BLOCK; ++i) {
127
- vmovups (ymm_src, ptr[param1 + offset]);
128
- vmaxps (ymm_dst, ymm_zero, ymm_src);
129
- vmovups (ptr[param2 + offset], ymm_dst);
130
- offset += sizeof (float ) * AVX_FLOAT_BLOCK;
131
- }
132
- int rest = num_ % AVX_FLOAT_BLOCK;
133
- if (rest >= 4 ) {
134
- vmovups (xmm_src, ptr[param1 + offset]);
135
- vmaxps (xmm_dst, xmm_zero, xmm_src);
136
- vmovups (ptr[param2 + offset], xmm_dst);
137
- offset += sizeof (float ) * 4 ;
138
- rest -= 4 ;
139
- }
140
- if (rest >= 2 ) {
141
- vmovups (xmm_src, ptr[param1 + offset]);
142
- vmaxps (xmm_dst, xmm_zero, xmm_src);
143
- vmovq (ptr[param2 + offset], xmm_dst);
144
- offset += sizeof (float ) * 2 ;
145
- rest -= 2 ;
146
- }
147
- if (rest > 0 ) {
148
- vmovups (xmm_src, ptr[param1 + offset]);
149
- vmaxps (xmm_dst, xmm_zero, xmm_src);
150
- vmovss (ptr[param2 + offset], xmm_dst);
151
- }
152
- ret ();
153
- }
154
-
155
121
#define ALIGN32 __attribute__ ((aligned(32 )))
156
122
#define EXP_HIG 88 .3762626647949f
157
123
#define EXP_LOW -88 .3762626647949f
@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
207
173
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES (0x7f )};
208
174
static int g_tmp_mem[16 ] ALIGN32 = {0 };
209
175
210
- bool VExpJitCode::init (int d) {
211
- return MayIUse (avx) && d == 8 ; // only 8 yet
176
+ bool VActJitCode::init (int d, operand_type type) {
177
+ bool ok = MayIUse (avx);
178
+ if (type == operand_type::relu) {
179
+ return ok;
180
+ } else {
181
+ return ok && d == 8 ; // only 8 yet
182
+ }
212
183
}
213
184
214
- void VExpJitCode::exp_ymm (ymm_t & ymm_src, ymm_t & ymm_dst) {
215
- // use reg rax and ymm 2~5
216
- reg64_t reg_ptr_global = rax;
217
- ymm_t ymm_fx = ymm_t (2 );
218
- ymm_t ymm_fy = ymm_t (3 );
219
- ymm_t ymm_mask = ymm_t (4 );
220
- ymm_t ymm_tmp = ymm_t (5 );
185
+ void VActJitCode::relu_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, ymm_t & ymm_zero) {
186
+ vmaxps (ymm_dst, ymm_zero, ymm_src);
187
+ }
188
+
189
+ void VActJitCode::exp_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
190
+ int fy_idx, int mask_idx, int tmp_idx) {
221
191
assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enfore
192
+ // check all idx can not equal
193
+ ymm_t ymm_fx = ymm_t (fx_idx);
194
+ ymm_t ymm_fy = ymm_t (fy_idx);
195
+ ymm_t ymm_mask = ymm_t (mask_idx);
196
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
197
+ reg64_t reg_ptr_global = rax;
222
198
push (reg_ptr_global);
223
199
mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
224
200
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
291
267
pop (reg_ptr_global);
292
268
}
293
269
294
- void VExpJitCode::generate () {
295
- int offset = 0 ;
296
- vmovups (ymm_src, ptr[param1 + offset]);
297
- exp_ymm (ymm_src, ymm_dst);
298
- vmovups (ptr[param2 + offset], ymm_dst);
299
- ret ();
300
- }
301
-
302
- bool VSigmoidJitCode::init (int d) {
303
- return MayIUse (avx) && d == 8 ; // only 8 yet
304
- }
305
-
306
- void VSigmoidJitCode::sigmoid_ymm (ymm_t & ymm_src, ymm_t & ymm_dst) {
307
- // use ymm2
270
+ void VActJitCode::sigmoid_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
271
+ int fy_idx, int mask_idx, int tmp_idx) {
272
+ // y = 1 / (1 + e^-x)
273
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
308
274
reg64_t reg_ptr_global = rax;
309
- ymm_t ymm_tmp = ymm_t (2 );
310
275
push (reg_ptr_global);
311
276
mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
312
277
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
315
280
vmaxps (ymm_src, ymm_src, ymm_tmp);
316
281
vxorps (ymm_tmp, ymm_tmp, ymm_tmp);
317
282
vsubps (ymm_src, ymm_tmp, ymm_src);
318
- exp_ymm (ymm_src, ymm_dst );
283
+ exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx );
319
284
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
320
285
vaddps (ymm_dst, ymm_dst, ymm_tmp);
321
286
vdivps (ymm_dst, ymm_tmp, ymm_dst);
322
287
pop (reg_ptr_global);
323
288
}
324
289
325
- void VSigmoidJitCode::generate () {
326
- int offset = 0 ;
327
- vmovups (ymm_src, ptr[param1 + offset]);
328
- sigmoid_ymm (ymm_src, ymm_dst);
329
- vmovups (ptr[param2 + offset], ymm_dst);
330
- ret ();
331
- }
332
-
333
- bool VTanhJitCode::init (int d) {
334
- return MayIUse (avx) && d == 8 ; // only 8 yet
335
- }
336
-
337
- void VTanhJitCode::vtanh_ymm (ymm_t & ymm_src, ymm_t & ymm_dst) {
290
+ void VActJitCode::tanh_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
291
+ int fy_idx, int mask_idx, int tmp_idx) {
338
292
// y = 2 / (1 + e^(-2x)) - 1
339
- // use ymm2, ymm3
293
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
294
+ ymm_t ymm_zero = ymm_t (mask_idx);
340
295
reg64_t reg_ptr_global = rax;
341
- ymm_t ymm_tmp = ymm_t (2 );
342
- ymm_t ymm_zero = ymm_t (3 );
343
296
push (reg_ptr_global);
344
297
mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
345
298
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
346
299
vxorps (ymm_zero, ymm_zero, ymm_zero);
347
300
vsubps (ymm_tmp, ymm_zero, ymm_tmp);
348
301
vmulps (ymm_src, ymm_src, ymm_tmp);
349
- exp_ymm (ymm_src, ymm_dst );
302
+ exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx );
350
303
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
351
304
vaddps (ymm_dst, ymm_dst, ymm_tmp);
352
305
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
356
309
pop (reg_ptr_global);
357
310
}
358
311
359
- void VTanhJitCode::generate () {
312
+ void VActJitCode::generate () {
313
+ xmm_t xmm_zero = xmm_t (2 );
314
+ ymm_t ymm_zero = ymm_t (2 );
315
+ if (type_ == operand_type::relu) {
316
+ vxorps (ymm_zero, ymm_zero, ymm_zero);
317
+ }
360
318
int offset = 0 ;
361
- vmovups (ymm_src, ptr[param1 + offset]);
362
- vtanh_ymm (ymm_src, ymm_dst);
363
- vmovups (ptr[param2 + offset], ymm_dst);
319
+ for (int i = 0 ; i < num_ / AVX_FLOAT_BLOCK; ++i) {
320
+ vmovups (ymm_src, ptr[param1 + offset]);
321
+ switch (type_) {
322
+ case operand_type::relu:
323
+ relu_ymm (ymm_dst, ymm_src, ymm_zero);
324
+ break ;
325
+ case operand_type::exp:
326
+ exp_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
327
+ break ;
328
+ case operand_type::sigmoid:
329
+ sigmoid_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
330
+ break ;
331
+ case operand_type::tanh:
332
+ tanh_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
333
+ break ;
334
+ case operand_type::identity:
335
+ break ;
336
+ default :
337
+ break ;
338
+ }
339
+ vmovups (ptr[param2 + offset], ymm_dst);
340
+ offset += sizeof (float ) * AVX_FLOAT_BLOCK;
341
+ }
342
+ if (type_ != operand_type::relu) {
343
+ // TODO(TJ): remove me
344
+ ret ();
345
+ return ;
346
+ }
347
+ int rest = num_ % AVX_FLOAT_BLOCK;
348
+ if (rest >= 4 ) {
349
+ vmovups (xmm_src, ptr[param1 + offset]);
350
+ vmaxps (xmm_dst, xmm_zero, xmm_src);
351
+ vmovups (ptr[param2 + offset], xmm_dst);
352
+ offset += sizeof (float ) * 4 ;
353
+ rest -= 4 ;
354
+ }
355
+ if (rest >= 2 ) {
356
+ vmovups (xmm_src, ptr[param1 + offset]);
357
+ vmaxps (xmm_dst, xmm_zero, xmm_src);
358
+ vmovq (ptr[param2 + offset], xmm_dst);
359
+ offset += sizeof (float ) * 2 ;
360
+ rest -= 2 ;
361
+ }
362
+ if (rest > 0 ) {
363
+ vmovups (xmm_src, ptr[param1 + offset]);
364
+ vmaxps (xmm_dst, xmm_zero, xmm_src);
365
+ vmovss (ptr[param2 + offset], xmm_dst);
366
+ }
364
367
ret ();
365
368
}
366
369
0 commit comments