@@ -132,56 +132,8 @@ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
132
132
int g_tmp_mem[16 ] ALIGN32 = {0 };
133
133
134
134
bool VActJitCode::init (int d, operand_type type) {
135
- bool ok = MayIUse (avx);
136
- if (type == operand_type::relu || type == operand_type::exp) {
137
- // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
138
- return ok;
139
- } else {
140
- // TODO(TJ): support more
141
- return ok && d % 8 == 0 ;
142
- }
143
- }
144
-
145
- void VActJitCode::sigmoid_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
146
- int fy_idx, int mask_idx, int tmp_idx) {
147
- // y = 1 / (1 + e^-x)
148
- ymm_t ymm_tmp = ymm_t (tmp_idx);
149
- reg64_t reg_ptr_global = rax;
150
- push (reg_ptr_global);
151
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
152
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
153
- vminps (ymm_src, ymm_src, ymm_tmp);
154
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
155
- vmaxps (ymm_src, ymm_src, ymm_tmp);
156
- vxorps (ymm_tmp, ymm_tmp, ymm_tmp);
157
- vsubps (ymm_src, ymm_tmp, ymm_src);
158
- exp_jmm<ymm_t >(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
159
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
160
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
161
- vdivps (ymm_dst, ymm_tmp, ymm_dst);
162
- pop (reg_ptr_global);
163
- }
164
-
165
- void VActJitCode::tanh_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
166
- int fy_idx, int mask_idx, int tmp_idx) {
167
- // y = 2 / (1 + e^(-2x)) - 1
168
- ymm_t ymm_tmp = ymm_t (tmp_idx);
169
- ymm_t ymm_zero = ymm_t (mask_idx);
170
- reg64_t reg_ptr_global = rax;
171
- push (reg_ptr_global);
172
- mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
173
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
174
- vxorps (ymm_zero, ymm_zero, ymm_zero);
175
- vsubps (ymm_tmp, ymm_zero, ymm_tmp);
176
- vmulps (ymm_src, ymm_src, ymm_tmp);
177
- exp_jmm<ymm_t >(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
178
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
179
- vaddps (ymm_dst, ymm_dst, ymm_tmp);
180
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
181
- vdivps (ymm_dst, ymm_tmp, ymm_dst);
182
- vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
183
- vsubps (ymm_dst, ymm_dst, ymm_tmp);
184
- pop (reg_ptr_global);
135
+ // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
136
+ return MayIUse (avx);
185
137
}
186
138
187
139
void VActJitCode::generate () {
@@ -201,10 +153,10 @@ void VActJitCode::generate() {
201
153
exp_jmm<ymm_t >(ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
202
154
break ;
203
155
case operand_type::sigmoid:
204
- sigmoid_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
156
+ sigmoid_jmm< ymm_t > (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
205
157
break ;
206
158
case operand_type::tanh:
207
- tanh_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
159
+ tanh_jmm< ymm_t > (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
208
160
break ;
209
161
case operand_type::identity:
210
162
break ;
@@ -214,11 +166,6 @@ void VActJitCode::generate() {
214
166
vmovups (ptr[param2 + offset], ymm_dst);
215
167
offset += sizeof (float ) * YMM_FLOAT_BLOCK;
216
168
}
217
- if (type_ != operand_type::relu && type_ != operand_type::exp) {
218
- // TODO(TJ): remove me
219
- ret ();
220
- return ;
221
- }
222
169
int rest = num_ % YMM_FLOAT_BLOCK;
223
170
int block = XMM_FLOAT_BLOCK;
224
171
while (rest > 0 ) {
@@ -236,6 +183,12 @@ void VActJitCode::generate() {
236
183
case operand_type::exp:
237
184
exp_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
238
185
break ;
186
+ case operand_type::sigmoid:
187
+ sigmoid_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
188
+ break ;
189
+ case operand_type::tanh:
190
+ tanh_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
191
+ break ;
239
192
default :
240
193
break ;
241
194
}
0 commit comments