diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index cfa362420f9..fada50c2d31 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -10,6 +10,7 @@ #ifdef __aarch64__ #include +#include #endif using torch::executor::BFloat16; @@ -80,32 +81,37 @@ f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) { } #endif +template static ET_INLINE void dot_with_fp32_arith_main_inner_loop( const BFloat16* vec1, const BFloat16* vec2, float32x4_t sum[kF32RegistersPerIteration], int registerPairIndex) { #ifdef __ARM_FEATURE_BF16 - const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); - sum[registerPairIndex] = - f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); -#else - const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); - - sum[2 * registerPairIndex] = f32_fma_bf16( - sum[2 * registerPairIndex], - vget_low_u16(temp_vec1), - vget_low_u16(temp_vec2)); - sum[2 * registerPairIndex + 1] = f32_fma_bf16( - sum[2 * registerPairIndex + 1], - vget_high_u16(temp_vec1), - vget_high_u16(temp_vec2)); + if (useBfloat16Dot) { + const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + sum[registerPairIndex] = + f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); + } else { +#endif + const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + + sum[2 * registerPairIndex] = f32_fma_bf16( + sum[2 * registerPairIndex], + vget_low_u16(temp_vec1), + vget_low_u16(temp_vec2)); + sum[2 * registerPairIndex + 1] = f32_fma_bf16( + sum[2 * registerPairIndex + 1], + vget_high_u16(temp_vec1), + vget_high_u16(temp_vec2)); +#ifdef __ARM_FEATURE_BF16 + } #endif } @@ -121,7 +127,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); } -template +template float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); @@ -130,7 +136,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { const auto* vec2_ = vec2 + j; utils::ForcedUnroll{}( [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE { - dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); + dot_with_fp32_arith_main_inner_loop( + vec1_, vec2_, sum, k); }); } auto reducedSum = reduce(sum); @@ -157,7 +164,15 @@ float bf16_dot_with_fp32_arith( const BFloat16* vec1, const BFloat16* vec2, int64_t len) { - return dot_with_fp32_arith(vec1, vec2, len); +#ifdef __ARM_FEATURE_BF16 + if (cpuinfo_has_arm_bf16()) { + return dot_with_fp32_arith(vec1, vec2, len); + } else { +#endif + return dot_with_fp32_arith(vec1, vec2, len); +#ifdef __ARM_FEATURE_BF16 + } +#endif } #endif } // namespace internal diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 04ee0cfde42..374895ecbbc 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -129,6 +129,14 @@ def define_libs(): ] if not runtime.is_oss else [], "DEFAULT": [], }), + fbandroid_platform_compiler_flags = [ + ( + "^android-arm64.*$", + [ + "-march=armv8+bf16", + ], + ), + ], fbandroid_platform_preprocessor_flags = [ ( "^android-arm64.*$", @@ -145,6 +153,9 @@ def define_libs(): ], ), ], + fbobjc_compiler_flags = [ + "-march=armv8+bf16", + ], fbobjc_exported_preprocessor_flags = [ "-DET_BUILD_WITH_BLAS", "-DET_BUILD_FOR_APPLE", diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp index 96875cc6f77..47f8925af08 100644 --- a/kernels/test/op_linear_test.cpp +++ b/kernels/test/op_linear_test.cpp @@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest { } } - // matmul gives 4 * 2 * 3 = 24 - Tensor x = tf.full({3, 4}, 2); - Tensor y = tf.full({5, 4}, 3); + // matmul gives 32 * 2 * 3 = 192 + Tensor x = tf.full({3, 32}, 2); + Tensor y = tf.full({5, 32}, 3); // Output shape should be (3, 5) Tensor out = tf.zeros({3, 5}); op_linear_out(x, y, out); - Tensor expected = tf.full({3, 5}, 24); + Tensor expected = tf.full({3, 5}, 192); EXPECT_TENSOR_EQ(out, expected); } diff --git a/shim/xplat/executorch/build/env_interface.bzl b/shim/xplat/executorch/build/env_interface.bzl index 5b0acd36dab..b6e30cd9f65 100644 --- a/shim/xplat/executorch/build/env_interface.bzl +++ b/shim/xplat/executorch/build/env_interface.bzl @@ -118,7 +118,8 @@ def _remove_platform_specific_args(kwargs): """ keys = [] for key in kwargs: - if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"): + if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or + key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")): keys.append(key) for key in keys: kwargs.pop(key)