From baa6dc1d0e45192d47a3214e1479cd8dcd7e57c2 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Tue, 10 Sep 2024 16:16:29 -0700
Subject: [PATCH] [ExecuTorch] Build optimized kernels with bf16 support and
 gate usage at runtime

Differential Revision: [D62466496](https://our.internmc.facebook.com/intern/diff/D62466496/)

[ghstack-poisoned]
---
 kernels/optimized/blas/BlasKernel.cpp | 62 +++++++++++++++------------
 kernels/optimized/lib_defs.bzl        | 11 +++++
 kernels/test/op_linear_test.cpp       |  8 ++--
 3 files changed, 49 insertions(+), 32 deletions(-)

diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp
index 813e9e9efc7..335e4aa6a10 100644
--- a/kernels/optimized/blas/BlasKernel.cpp
+++ b/kernels/optimized/blas/BlasKernel.cpp
@@ -10,6 +10,7 @@
 
 #ifdef __aarch64__
 #include <arm_neon.h>
+#include <cpuinfo.h>
 #endif
 
 using torch::executor::BFloat16;
@@ -73,39 +74,39 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
   return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
 }
 
-#ifdef __ARM_FEATURE_BF16
-static ET_INLINE float32x4_t f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
+static ET_INLINE float32x4_t
+f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
   return vbfdotq_f32(a, b, c);
 }
-#endif
 
+template <bool useBfloat16Dot>
 static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
-  // TODO: detect intrinsic availability, use them if they're available.
-  // __ARM_FEATURE_BF16 Load a pair of f32 registers at a time.
-#ifdef __ARM_FEATURE_BF16
-  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  // TODO: we leave half of sum unused. Does this cause suboptimal code generation?
-  sum[registerPairIndex] = f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
-#else
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-  sum[2 * registerPairIndex] = f32_fma_bf16(
-      sum[2 * registerPairIndex],
-      vget_low_u16(temp_vec1),
-      vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-      sum[2 * registerPairIndex + 1],
-      vget_high_u16(temp_vec1),
-      vget_high_u16(temp_vec2));
-#endif
+  if (useBfloat16Dot) {
+    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    sum[registerPairIndex] =
+        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+  } else {
+    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+    sum[2 * registerPairIndex] = f32_fma_bf16(
+        sum[2 * registerPairIndex],
+        vget_low_u16(temp_vec1),
+        vget_low_u16(temp_vec2));
+    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+        sum[2 * registerPairIndex + 1],
+        vget_high_u16(temp_vec1),
+        vget_high_u16(temp_vec2));
+  }
 }
 
 static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
@@ -120,7 +121,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
 }
 
-template <typename T>
+template <typename T, bool useBfloat16Dot>
 float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -129,7 +130,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
     const auto* vec2_ = vec2 + j;
     utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
         [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+          dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
+              vec1_, vec2_, sum, k);
         });
   }
   auto reducedSum = reduce(sum);
@@ -156,7 +158,11 @@ float bf16_dot_with_fp32_arith(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+  if (cpuinfo_has_arm_bf16()) {
+    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
+  } else {
+    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
+  }
 }
 #endif
 } // namespace internal
diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl
index 23bfda9d5a6..7ef1cc7e101 100644
--- a/kernels/optimized/lib_defs.bzl
+++ b/kernels/optimized/lib_defs.bzl
@@ -129,6 +129,14 @@ def define_libs():
             ] if not runtime.is_oss else [],
             "DEFAULT": [],
         }),
+        fbandroid_platform_compiler_flags = [
+            (
+                "^android-arm64.*$",
+                [
+                    "-march=armv8+bf16",
+                ],
+            ),
+        ],
         fbandroid_platform_preprocessor_flags = [
             (
                 "^android-arm64.*$",
@@ -145,6 +153,9 @@ def define_libs():
                 ],
             ),
         ],
+        fbobjc_compiler_flags = [
+            "-march=armv8+bf16",
+        ],
         fbobjc_exported_preprocessor_flags = [
             "-DET_BUILD_WITH_BLAS",
             "-DET_BUILD_FOR_APPLE",
diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp
index 96875cc6f77..47f8925af08 100644
--- a/kernels/test/op_linear_test.cpp
+++ b/kernels/test/op_linear_test.cpp
@@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
       }
     }
 
-    // matmul gives 4 * 2 * 3 = 24
-    Tensor x = tf.full({3, 4}, 2);
-    Tensor y = tf.full({5, 4}, 3);
+    // matmul gives 32 * 2 * 3 = 192
+    Tensor x = tf.full({3, 32}, 2);
+    Tensor y = tf.full({5, 32}, 3);
 
     // Output shape should be (3, 5)
     Tensor out = tf.zeros({3, 5});
 
     op_linear_out(x, y, out);
 
-    Tensor expected = tf.full({3, 5}, 24);
+    Tensor expected = tf.full({3, 5}, 192);
 
     EXPECT_TENSOR_EQ(out, expected);
   }
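
Note (reviewer sketch): the BlasKernel.cpp change compiles both the BFDOT path and the plain widen-to-f32 FMA path by hanging them off a bool template parameter, builds the library with -march=armv8+bf16 on arm64 Android and Apple targets (per the lib_defs.bzl hunks), and picks the path at runtime with cpuinfo_has_arm_bf16(). The standalone C++ sketch below illustrates only that dispatch pattern; the names dot_sketch and dot_dispatch and the scalar loop body are hypothetical stand-ins rather than code from this diff, and it assumes the cpuinfo library is available to link against.

// Standalone sketch of the runtime-gated dispatch used in BlasKernel.cpp.
// Hypothetical names; not part of this diff.
#include <cstdint>
#include <cstdio>

#include <cpuinfo.h>

namespace {

// Stand-in for the vectorized kernel: the bool template parameter selects the
// bf16-dot variant at compile time, so both variants exist in the binary.
template <bool useBfloat16Dot>
float dot_sketch(const float* a, const float* b, int64_t len) {
  float sum = 0.0f;
  for (int64_t i = 0; i < len; ++i) {
    // The real kernel uses vbfdotq_f32 when useBfloat16Dot is true and the
    // widen-to-f32 FMA path otherwise; this scalar loop is only a placeholder.
    sum += a[i] * b[i];
  }
  return sum;
}

// Runtime gate mirroring bf16_dot_with_fp32_arith(): query the CPU and pick a
// specialization. cpuinfo_initialize() must succeed before feature queries.
float dot_dispatch(const float* a, const float* b, int64_t len) {
  if (cpuinfo_initialize() && cpuinfo_has_arm_bf16()) {
    return dot_sketch<true>(a, b, len);
  }
  return dot_sketch<false>(a, b, len);
}

} // namespace

int main() {
  const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float b[4] = {5.0f, 6.0f, 7.0f, 8.0f};
  std::printf("dot = %f\n", dot_dispatch(a, b, 4));
  return 0;
}

Building with -march=armv8+bf16 is what makes the useBfloat16Dot == true specialization legal to compile; the cpuinfo check keeps it from executing on cores that lack the BF16 extension.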