diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp
index 7202c8cd472..cfa362420f9 100644
--- a/kernels/optimized/blas/BlasKernel.cpp
+++ b/kernels/optimized/blas/BlasKernel.cpp
@@ -73,13 +73,26 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
   return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
 }
+#ifdef __ARM_FEATURE_BF16
+static ET_INLINE float32x4_t
+f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
+  return vbfdotq_f32(a, b, c);
+}
+#endif
+
 static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
-  // TODO: detect intrinsic availability, use them if they're available.
-  // __ARM_FEATURE_BF16 Load a pair of f32 registers at a time.
+#ifdef __ARM_FEATURE_BF16
+  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  sum[registerPairIndex] =
+      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+#else
   const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
   const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
@@ -93,6 +106,7 @@ static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
       sum[2 * registerPairIndex + 1],
       vget_high_u16(temp_vec1),
       vget_high_u16(temp_vec2));
+#endif
 }
 
 static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
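
Note (not part of the patch): as a mental model for the new __ARM_FEATURE_BF16 path, vbfdotq_f32(a, b, c) adds the product of each adjacent pair of bf16 elements from b and c into the corresponding f32 lane of a, so a single BFDOT replaces the widen-then-FMA steps of the fallback path. The sketch below is illustrative only; the helper names bf16_to_float and scalar_f32_dot_bf16 are invented for the example, and BFDOT's exact intermediate rounding behavior is ignored.

#include <cstdint>
#include <cstring>

// Widen one bf16 value (raw 16-bit pattern) to float: bf16 occupies the
// upper half of an IEEE-754 binary32.
static float bf16_to_float(uint16_t bits) {
  uint32_t word = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &word, sizeof(out));
  return out;
}

// Scalar model of f32_dot_bf16(acc, b, c) over one 128-bit register:
// each of the four f32 accumulator lanes gains the dot product of one
// adjacent pair of bf16 elements.
static void scalar_f32_dot_bf16(float acc[4],
                                const uint16_t b[8],
                                const uint16_t c[8]) {
  for (int lane = 0; lane < 4; ++lane) {
    acc[lane] += bf16_to_float(b[2 * lane]) * bf16_to_float(c[2 * lane]) +
                 bf16_to_float(b[2 * lane + 1]) * bf16_to_float(c[2 * lane + 1]);
  }
}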