@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
156
156
} \
157
157
}
158
158
159
- #define INTRIAVX2_FLOAT (block ) \
159
+ #define INTRIAVX2_FLOAT (isa, block ) \
160
160
template <> \
161
- CRFDecodeKernelImpl<float , jit::avx2, block>::CRFDecodeKernelImpl( \
162
- int tag_num) \
161
+ CRFDecodeKernelImpl<float , isa, block>::CRFDecodeKernelImpl(int tag_num) \
163
162
: CRFDecodeKernel<float >() { \
164
163
this ->num_ = tag_num; \
165
164
this ->end_ = this ->num_ / AVX2_FLOAT_BLOCK; \
166
165
this ->rest_ = this ->num_ % AVX2_FLOAT_BLOCK; \
167
166
} \
168
167
template <> \
169
- void CRFDecodeKernelImpl<float , jit::avx2 , block>::Compute( \
168
+ void CRFDecodeKernelImpl<float , isa , block>::Compute( \
170
169
const int seq_len, const float * x, const float * w, float * alpha, \
171
170
int * track) const { \
172
171
INIT_ALPHA (AVX2_FLOAT_BLOCK) \
@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
224
223
int j_offset = 0 ; \
225
224
for (int j = 0 ; j <= this ->end_ ; ++j) { \
226
225
/* Initialize the variables of maximum score and location.*/ \
227
- __m512 max_score = _mm512_set1_ps (-std::numeric_limits<T >::max ()); \
226
+ __m512 max_score = _mm512_set1_ps (-std::numeric_limits<float >::max ()); \
228
227
__m512i max_j = _mm512_setzero_si512 (); \
229
228
/* Calculate the offset of transition_weights.*/ \
230
229
int trans_offset = state_trans_base_idx * this ->num_ + j_offset; \
@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
245
244
__m512 x_content = \
246
245
_mm512_loadu_ps (x + seq_offset + this ->num_ + j_offset); \
247
246
max_score = _mm512_add_ps (max_score, x_content); \
248
- _mm512_storeu_ps (alpha_value + seq_offset + this ->tag_num_ + j_offset, \
247
+ _mm512_storeu_ps (alpha + seq_offset + this ->num_ + j_offset, \
249
248
max_score); \
250
249
_mm512_storeu_si512 (reinterpret_cast <__m512i*>(track + seq_offset + \
251
250
this ->num_ + j_offset), \
@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
271
270
INTRIAVX_FLOAT (kGT16 );
272
271
#endif
273
272
#ifdef __AVX2__
274
- INTRIAVX2_FLOAT (kEQ8 );
275
- INTRIAVX2_FLOAT (kGT8LT16 );
276
- INTRIAVX2_FLOAT (kEQ16 );
277
- INTRIAVX2_FLOAT (kGT16 );
273
+ INTRIAVX2_FLOAT (jit::avx2, kEQ8 );
274
+ INTRIAVX2_FLOAT (jit::avx2, kGT8LT16 );
275
+ INTRIAVX2_FLOAT (jit::avx2, kEQ16 );
276
+ INTRIAVX2_FLOAT (jit::avx2, kGT16 );
278
277
#endif
279
278
#ifdef __AVX512F__
280
- INTRIAVX2_FLOAT (kEQ8 );
281
- INTRIAVX2_FLOAT (kGT8LT16 );
279
+ INTRIAVX2_FLOAT (jit::avx512f, kEQ8 );
280
+ INTRIAVX2_FLOAT (jit::avx512f, kGT8LT16 );
282
281
INTRIAVX512_FLOAT (kEQ16 );
283
282
INTRIAVX512_FLOAT (kGT16 );
284
283
#endif
0 commit comments