From b2d6138bf179a2ed148623e1d7820ef8506954b1 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 10 Sep 2024 16:16:20 -0700 Subject: [PATCH] [ExecuTorch] Use bfdot if compiled with ARM_FEATURE_BF16 Port of https://github.com/pytorch/pytorch/pull/127488 . Differential Revision: [D62159105](https://our.internmc.facebook.com/intern/diff/D62159105/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index 7202c8cd472..813e9e9efc7 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -73,6 +73,12 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); } +#ifdef __ARM_FEATURE_BF16 +static ET_INLINE float32x4_t f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) { + return vbfdotq_f32(a, b, c); +} +#endif + static ET_INLINE void dot_with_fp32_arith_main_inner_loop( const BFloat16* vec1, const BFloat16* vec2, @@ -80,6 +86,12 @@ static ET_INLINE void dot_with_fp32_arith_main_inner_loop( int registerPairIndex) { // TODO: detect intrinsic availability, use them if they're available. // __ARM_FEATURE_BF16 Load a pair of f32 registers at a time. +#ifdef __ARM_FEATURE_BF16 + const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + // TODO: we leave half of sum unused. Does this cause suboptimal code generation? + sum[registerPairIndex] = f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); +#else const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( @@ -93,6 +105,7 @@ static ET_INLINE void dot_with_fp32_arith_main_inner_loop( sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2)); +#endif } static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(