Skip to content

Commit 64d5b43

Browse files
committed
fix crf decode avx512
1 parent 21487d7 commit 64d5b43

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

paddle/fluid/operators/math/jit_kernel_crf_decode.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
156156
} \
157157
}
158158

159-
#define INTRIAVX2_FLOAT(block) \
159+
#define INTRIAVX2_FLOAT(isa, block) \
160160
template <> \
161-
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \
162-
int tag_num) \
161+
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
163162
: CRFDecodeKernel<float>() { \
164163
this->num_ = tag_num; \
165164
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
166165
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
167166
} \
168167
template <> \
169-
void CRFDecodeKernelImpl<float, jit::avx2, block>::Compute( \
168+
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
170169
const int seq_len, const float* x, const float* w, float* alpha, \
171170
int* track) const { \
172171
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
224223
int j_offset = 0; \
225224
for (int j = 0; j <= this->end_; ++j) { \
226225
/* Initialize the variables of maximum score and location.*/ \
227-
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max()); \
226+
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
228227
__m512i max_j = _mm512_setzero_si512(); \
229228
/* Calculate the offset of transition_weights.*/ \
230229
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
245244
__m512 x_content = \
246245
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
247246
max_score = _mm512_add_ps(max_score, x_content); \
248-
_mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \
247+
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
249248
max_score); \
250249
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
251250
this->num_ + j_offset), \
@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
271270
INTRIAVX_FLOAT(kGT16);
272271
#endif
273272
#ifdef __AVX2__
274-
INTRIAVX2_FLOAT(kEQ8);
275-
INTRIAVX2_FLOAT(kGT8LT16);
276-
INTRIAVX2_FLOAT(kEQ16);
277-
INTRIAVX2_FLOAT(kGT16);
273+
INTRIAVX2_FLOAT(jit::avx2, kEQ8);
274+
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
275+
INTRIAVX2_FLOAT(jit::avx2, kEQ16);
276+
INTRIAVX2_FLOAT(jit::avx2, kGT16);
278277
#endif
279278
#ifdef __AVX512F__
280-
INTRIAVX2_FLOAT(kEQ8);
281-
INTRIAVX2_FLOAT(kGT8LT16);
279+
INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
280+
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
282281
INTRIAVX512_FLOAT(kEQ16);
283282
INTRIAVX512_FLOAT(kGT16);
284283
#endif

0 commit comments

Comments
 (0)