From 120c4baa8bb867da1bc0faf8a6276228118bcaaf Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 13 May 2025 21:41:03 -0700 Subject: [PATCH 1/5] Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel The two files were similar, but diverged due to recent changes. Since we have sharing of PyTorch headers, we can keep them mostly the same; differences are some of the namespace stuff and a couple of EXECUTORCH NOTEs. Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.cpp | 399 ++++++++++++++++---------- kernels/optimized/lib_defs.bzl | 3 +- 2 files changed, 255 insertions(+), 147 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index a3e2172504d..209b700a5e4 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -6,148 +6,225 @@ * LICENSE file in the root directory of this source tree. */ +// NOTE: This file is mostly the same as +// ReducedPrecisionFloatGemvFastPathKernel.cpp in PyTorch. Actually +// sharing the two versions is a TODO. #include +#include +#include +#include + +#include +#include +#include +#include #ifdef __aarch64__ #include #include #endif +namespace vec = at::vec; +using executorch::extension::parallel_for; using torch::executor::BFloat16; +using torch::executor::Half; -namespace executorch { -namespace cpublas { -namespace internal { -#ifdef __aarch64__ -static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) { -#ifdef __ARM_FEATURE_FMA - return vfmaq_f32(a, b, c); +namespace executorch::cpublas::internal { +constexpr auto kF32RegisterPairsPerIteration = 4; +constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2; +constexpr auto kF32ElementsPerRegister = vec::Vectorized::size(); +constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister; + +namespace { +template +constexpr int IntegerLog2(T n, int p = 0) { + return (n <= 1) ? p : IntegerLog2(n / 2, p + 1); +} + +/* + * NOTE [ GGML Copyright Notice ] + * The below reduce overload and fp16_dot_with_fp16_arith function is + * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility + * functions, so here is the required copyright notice: + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +float reduce(vec::Vectorized x) { +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) + return vaddvq_f32(x); #else - return vaddq_f32(a, vmulq_f32(b, c)); -#endif // __ARM_FEATURE_FMA + return vec::vec_reduce_all( + std::plus>(), + x); +#endif } // The below reduce overload and fp16_dot_with_fp32_arith are adapted // from llama.cpp's ggml_vec_dot_f32 and surrounding utility // functions. See NOTE [ GGML Copyright Notice ] above for the // required notice. - -// We need the shift for reduce(), hence the extra constants. -static constexpr auto kF32ElementsPerIterationShift = 5; -static constexpr auto kF32ElementsPerIteration = 1 - << kF32ElementsPerIterationShift; -static_assert(kF32ElementsPerIteration == 32); - -static constexpr auto kF32ElementsPerRegisterShift = 2; -static constexpr auto kF32ElementsPerRegister = 1 - << kF32ElementsPerRegisterShift; -static_assert(kF32ElementsPerRegister == 4); - -static constexpr auto kF32RegisterPairsPerIteration = 4; -static constexpr auto kF32RegistersPerIteration = - kF32RegisterPairsPerIteration * 2; -static constexpr auto kF32RegistersPerIterationShift = 3; -static_assert( - kF32RegistersPerIteration == - kF32ElementsPerIteration / kF32ElementsPerRegister); -static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift); - -static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { +float reduce(vec::VectorizedN& x) { int offset = kF32RegistersPerIteration; - utils::ForcedUnroll{}( - [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE { - offset /= 2; - for (int i = 0; i < offset; ++i) { - x[i] = vaddq_f32(x[i], x[offset + i]); - } - }); - return vaddvq_f32(x[0]); + c10::ForcedUnroll{}([&offset, &x](auto idx) { + offset /= 2; + for (const auto i : c10::irange(offset)) { + x[i] = x[i] + x[offset + i]; + } + }); + return reduce(x[0]); } -static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) { - int32x4_t shift = vdupq_n_s32(16); - return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift)); -} +// We would have to write a separate SVE-specific path to use SVE +// BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path +// working. +#if __ARM_FEATURE_BF16_VECTOR_ARITHMETIC +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +// https://godbolt.org/z/z8P4Yncra +#define COMPILER_SUPPORTS_BF16_TARGET 1 +#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10 +// https://gcc.gnu.org/gcc-10/changes.html +// https://godbolt.org/z/cdGG7vn8o +#define COMPILER_SUPPORTS_BF16_TARGET 1 +#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +#define COMPILER_SUPPORTS_BF16_TARGET 0 +#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +#else // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC +#define COMPILER_SUPPORTS_BF16_TARGET 0 +#endif // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC -static ET_INLINE float32x4_t -f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { - return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); -} +#if COMPILER_SUPPORTS_BF16_TARGET +#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16"))) -#define ET_TARGET_ARM_BF16_ATTRIBUTE \ - __attribute__((target("arch=armv8.2-a+bf16"))) -ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t -f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) { - return vbfdotq_f32(a, b, c); -} - -ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void +TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_bfdot( const BFloat16* vec1, const BFloat16* vec2, - float32x4_t sum[kF32RegistersPerIteration], + vec::VectorizedN& sum, int registerPairIndex) { - const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + // NOTE[Intrinsics in bfdot variant]: We can't use + // vec::Vectorized::loadu here because linux-aarch64 GCC + // inexplicably can't convert Vectorized to + // bfloat16x8_t. I suspect a bug or incomplete + // __attribute__((target)) implementation. Intrinsics should be fine + // because we're using vbfdotq_f32 below anyway. + const auto temp_vec1 = vld1q_bf16( + reinterpret_cast( + &vec1[registerPairIndex * vec::Vectorized::size()])); + const auto temp_vec2 = vld1q_bf16( + reinterpret_cast( + &vec2[registerPairIndex * vec::Vectorized::size()])); sum[registerPairIndex] = - f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); + vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2); } -static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( - const BFloat16* vec1, - const BFloat16* vec2, - float32x4_t sum[kF32RegistersPerIteration], - int registerPairIndex) { - const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); - - sum[2 * registerPairIndex] = f32_fma_bf16( - sum[2 * registerPairIndex], - vget_low_u16(temp_vec1), - vget_low_u16(temp_vec2)); - sum[2 * registerPairIndex + 1] = f32_fma_bf16( - sum[2 * registerPairIndex + 1], - vget_high_u16(temp_vec1), - vget_high_u16(temp_vec2)); +TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE +void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + vec::Vectorized* tail_sum, + int idx) { + // See NOTE[Intrinsics in bfdot variant] above. + const auto temp_vec1 = vld1q_bf16(reinterpret_cast(&vec1[idx])); + const auto temp_vec2 = vld1q_bf16(reinterpret_cast(&vec2[idx])); + *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2); } -template -ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void -dot_with_fp32_arith_main_inner_loop( - const BFloat16* vec1, - const BFloat16* vec2, - float32x4_t sum[kF32RegistersPerIteration], - int registerPairIndex) { - if constexpr (useBfdot) { - dot_with_fp32_arith_main_inner_loop_bfdot( - vec1, vec2, sum, registerPairIndex); - } else { - dot_with_fp32_arith_main_inner_loop_no_bfdot( - vec1, vec2, sum, registerPairIndex); - } +#else +#define TARGET_ARM_BF16_ATTRIBUTE +#endif // COMPILER_SUPPORTS_BF16_TARGET + +namespace { + +[[maybe_unused]] std::pair, vec::Vectorized> fmadd( + const vec::Vectorized& a, + const vec::Vectorized& b, + const vec::Vectorized& acc_low, + const vec::Vectorized& acc_high) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high)); } -static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( - const BFloat16* vec1, - const BFloat16* vec2, - float32x4_t* tailSum, +[[maybe_unused]] vec::Vectorized fmadd( + const vec::Vectorized& acc, + const vec::Vectorized& a, + const vec::Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc)); +} +} // namespace + +template +C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( + const T* vec1, + const T* vec2, + vec::VectorizedN& sum, + int registerPairIndex) { + static_assert(std::is_same_v); + const auto temp_vec1 = vec::Vectorized::loadu(&vec1[registerPairIndex * vec::Vectorized::size()]); + const auto temp_vec2 = vec::Vectorized::loadu(&vec2[registerPairIndex * vec::Vectorized::size()]); + + const auto [result_low, result_high] = fmadd(temp_vec1, temp_vec2, sum[2 * registerPairIndex], sum[2 * registerPairIndex + 1]); + sum[2 * registerPairIndex] = result_low; + sum[2 * registerPairIndex + 1] = result_high; +} + +template +C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot( + const T* vec1, + const T* vec2, + vec::Vectorized* tail_sum, int idx) { - const auto temp_vec1 = - vld1_u16(reinterpret_cast(&vec1[idx])); - const auto temp_vec2 = - vld1_u16(reinterpret_cast(&vec2[idx])); - *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); + const auto temp_vec1 = vec::Vectorized::loadu(&vec1[idx]); + const auto temp_vec2 = vec::Vectorized::loadu(&vec2[idx]); + *tail_sum = fmadd(*tail_sum, temp_vec1, temp_vec2); } -namespace { +template +C10_ALWAYS_INLINE auto +dot_with_fp32_arith_main_loop_no_bfdot( + const T* vec1, + const T* vec2, + int64_t len) { + vec::VectorizedN sum(0); + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + const auto* vec1_ = vec1 + j; + const auto* vec2_ = vec2 + j; + c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k); + }); + } + return reduce(sum); +} + +#if COMPILER_SUPPORTS_BF16_TARGET template struct ForcedUnrollTargetBFloat16 { template - ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const { + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { ForcedUnrollTargetBFloat16{}(f); f(n - 1); } @@ -156,59 +233,89 @@ struct ForcedUnrollTargetBFloat16 { template <> struct ForcedUnrollTargetBFloat16<1> { template - ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const { + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { f(0); } }; -} // namespace - -template -ET_TARGET_ARM_BF16_ATTRIBUTE float -dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { - float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; +C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto +dot_with_fp32_arith_main_loop_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + int64_t len) { + vec::VectorizedN sum(0); const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); - for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) { + for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; - ForcedUnrollTargetBFloat16{}( - [vec1_, vec2_, &sum](auto k) - ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE { - dot_with_fp32_arith_main_inner_loop( - vec1_, vec2_, sum, k); - }); + ForcedUnrollTargetBFloat16{}([vec1_, vec2_, &sum](auto k) + C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k); + }); } - auto reducedSum = reduce(sum); - - // First-tier tail fixup: make sure we handle workloads that can - // benefit from vectorization, but don't fit into our fully unrolled - // loop above. - float32x4_t tailSum = vdupq_n_f32(0); - const auto len_aligned_4 = len & ~3; - for (int j = len_aligned; j < len_aligned_4; j += 4) { - dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); - } - auto reducedTail = vpaddq_f32(tailSum, tailSum); - reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); + return reduce(sum); +} +#endif // COMPILER_SUPPORTS_BF16_TARGET - // Second-tier tail fixup: handle all workloads. - for (int j = len_aligned_4; j < len; ++j) { - reducedSum += vec1[j] * vec2[j]; - } - return reducedSum; +static_assert( + (vec::Vectorized::size() & (vec::Vectorized::size() - 1)) == 0, + "Below code expects power-of-2 vector register size!"); + +// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with +// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not +// allow inlining a non-bf16-specific function into a bf16-specific +// function. We can work around this by duplicating the code into the +// bfdot and non-bfdot callsites. The code is in this macro to avoid +// actual copy/paste. +#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \ + /* First-tier tail fixup: make sure we handle workloads that can */ \ + /* benefit from vectorization, but don't fit into our fully unrolled */ \ + /* loop above. */ \ + vec::Vectorized tail_sum(0); \ + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \ + const auto len_aligned_vec = len & ~(vec::Vectorized::size() - 1); \ + for (int j = len_aligned; j < len_aligned_vec; j += vec::Vectorized::size()) { \ + dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \ + } \ + reduced_sum += reduce(tail_sum); \ + \ + /* Second-tier tail fixup: handle all workloads. */ \ + for (const auto j : c10::irange(len_aligned_vec, len)) { \ + /* Attempting to use Half here caused multiple test failures; */ \ + /* using float to unbreak. (Suspect we need a scalar FMA.) */ \ + float x1 = vec1[j]; \ + float x2 = vec2[j]; \ + reduced_sum += x1 * x2; \ + } \ + return reduced_sum + +#if COMPILER_SUPPORTS_BF16_TARGET +TARGET_ARM_BF16_ATTRIBUTE float +dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) { + auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len); + DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot); } +#endif // COMPILER_SUPPORTS_BF16_TARGET -float bf16_dot_with_fp32_arith( - const BFloat16* vec1, - const BFloat16* vec2, - int64_t len) { +template +C10_ALWAYS_INLINE float +dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) { + auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len); + DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot); +} +#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY + +} // namespace + +float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) { +#if COMPILER_SUPPORTS_BF16_TARGET if (cpuinfo_has_arm_bf16()) { - return dot_with_fp32_arith(vec1, vec2, len); - } else { - return dot_with_fp32_arith(vec1, vec2, len); + return dot_with_fp32_arith_bfdot(vec1, vec2, len); + } else +#endif // COMPILER_SUPPORTS_BF16_TARGET + { + return dot_with_fp32_arith_no_bfdot(vec1, vec2, len); } } -#endif // __aarch64__ -} // namespace internal -} // namespace cpublas -} // namespace executorch + +} // namespace executorch::cpublas::internal diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 85886365a01..7e8014f26e2 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -200,7 +200,7 @@ def define_libs(is_fbcode=False): exported_headers = native.glob([ "blas/**/*.h", ]), - compiler_flags = get_compiler_optimization_flags(), + compiler_flags = ["-Wno-pass-failed"] + get_compiler_optimization_flags(), header_namespace = "executorch/kernels/optimized", visibility = [ "//executorch/...", @@ -235,6 +235,7 @@ def define_libs(is_fbcode=False): "//executorch/extension/threadpool:threadpool", "//executorch/kernels/optimized:libutils", "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", ], **get_apple_framework_deps_kwargs(is_fbcode), ) From a3c268a7b9e0c0f351cb64ed0c7e901123ba39f6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 14 May 2025 16:16:26 -0700 Subject: [PATCH 2/5] Update on "Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel" The two files were similar, but diverged due to recent changes. Since we have sharing of PyTorch headers, we can keep them mostly the same; differences are some of the namespace stuff and a couple of EXECUTORCH NOTEs. Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.cpp | 8 ++++---- kernels/optimized/lib_defs.bzl | 11 ++++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index 209b700a5e4..f63cf3f6387 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -95,10 +95,13 @@ float reduce(vec::VectorizedN& x) { return reduce(x[0]); } +// EXECUTORCH NOTE: removed __ARM_FEATURE_BF16_VECTOR_ARITHMETIC gate +// added in https://github.com/pytorch/pytorch/pull/152766, which I +// complained on. + // We would have to write a separate SVE-specific path to use SVE // BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path // working. -#if __ARM_FEATURE_BF16_VECTOR_ARITHMETIC #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 // https://godbolt.org/z/z8P4Yncra #define COMPILER_SUPPORTS_BF16_TARGET 1 @@ -109,9 +112,6 @@ float reduce(vec::VectorizedN& x) { #else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 #define COMPILER_SUPPORTS_BF16_TARGET 0 #endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 -#else // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC -#define COMPILER_SUPPORTS_BF16_TARGET 0 -#endif // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC #if COMPILER_SUPPORTS_BF16_TARGET #define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16"))) diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 7e8014f26e2..d960db852bb 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -2,10 +2,6 @@ load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFOR load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -load( - "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", - "get_compiler_optimization_flags", -) # Because vec exists as a collection of header files, compile and preprocessor # flags applied to the vec target do not have any effect, since no compilation @@ -200,7 +196,12 @@ def define_libs(is_fbcode=False): exported_headers = native.glob([ "blas/**/*.h", ]), - compiler_flags = ["-Wno-pass-failed"] + get_compiler_optimization_flags(), + compiler_flags = ["-Wno-pass-failed"] + select({ + "ovr_config//runtime:fbcode": [], + # TODO: replace with get_compiler_optimization_flags from op_registration_util.bzl when that + # is re-enabled. + "DEFAULT": ["-Os"], + }), header_namespace = "executorch/kernels/optimized", visibility = [ "//executorch/...", From 4a91e45b8c670f58b99b565f9944d87c015c7e79 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 14 May 2025 16:50:04 -0700 Subject: [PATCH 3/5] fix CMake on "Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel" The two files were similar, but diverged due to recent changes. Since we have sharing of PyTorch headers, we can keep them mostly the same; differences are some of the namespace stuff and a couple of EXECUTORCH NOTEs. Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/) [ghstack-poisoned] --- kernels/optimized/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index ae6d8e6fcd3..bbd6999ad33 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -42,6 +42,7 @@ endif() # Build cpublas. list(TRANSFORM _optimized_cpublas__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(cpublas STATIC ${_optimized_cpublas__srcs}) +target_include_directories(cpublas PRIVATE ${TORCH_INCLUDE_DIRS}) target_link_libraries( cpublas PUBLIC executorch_core eigen_blas extension_threadpool ) From 294798bab0d554d5dc4e171d95b029581a544db6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 15 May 2025 11:00:28 -0700 Subject: [PATCH 4/5] lintrunner on "Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel" The two files were similar, but diverged due to recent changes. Since we have sharing of PyTorch headers, we can keep them mostly the same; differences are some of the namespace stuff and a couple of EXECUTORCH NOTEs. Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.cpp | 173 +++++++++++++++----------- 1 file changed, 99 insertions(+), 74 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index f63cf3f6387..4d833db65f6 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -33,7 +33,8 @@ namespace executorch::cpublas::internal { constexpr auto kF32RegisterPairsPerIteration = 4; constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2; constexpr auto kF32ElementsPerRegister = vec::Vectorized::size(); -constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister; +constexpr auto kF32ElementsPerIteration = + kF32RegistersPerIteration * kF32ElementsPerRegister; namespace { template @@ -58,8 +59,8 @@ constexpr int IntegerLog2(T n, int p = 0) { * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, @@ -74,9 +75,7 @@ float reduce(vec::Vectorized x) { #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) return vaddvq_f32(x); #else - return vec::vec_reduce_all( - std::plus>(), - x); + return vec::vec_reduce_all(std::plus>(), x); #endif } @@ -86,12 +85,13 @@ float reduce(vec::Vectorized x) { // required notice. float reduce(vec::VectorizedN& x) { int offset = kF32RegistersPerIteration; - c10::ForcedUnroll{}([&offset, &x](auto idx) { - offset /= 2; - for (const auto i : c10::irange(offset)) { - x[i] = x[i] + x[offset + i]; - } - }); + c10::ForcedUnroll{}( + [&offset, &x](auto idx) { + offset /= 2; + for (const auto i : c10::irange(offset)) { + x[i] = x[i] + x[offset + i]; + } + }); return reduce(x[0]); } @@ -102,16 +102,20 @@ float reduce(vec::VectorizedN& x) { // We would have to write a separate SVE-specific path to use SVE // BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path // working. -#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \ + defined(__clang__) && __clang_major__ > 15 // https://godbolt.org/z/z8P4Yncra #define COMPILER_SUPPORTS_BF16_TARGET 1 -#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10 +#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \ + !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10 // https://gcc.gnu.org/gcc-10/changes.html // https://godbolt.org/z/cdGG7vn8o #define COMPILER_SUPPORTS_BF16_TARGET 1 -#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && + // defined(__clang__) && __clang_major__ > 15 #define COMPILER_SUPPORTS_BF16_TARGET 0 -#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15 +#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && + // defined(__clang__) && __clang_major__ > 15 #if COMPILER_SUPPORTS_BF16_TARGET #define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16"))) @@ -128,25 +132,25 @@ dot_with_fp32_arith_main_inner_loop_bfdot( // bfloat16x8_t. I suspect a bug or incomplete // __attribute__((target)) implementation. Intrinsics should be fine // because we're using vbfdotq_f32 below anyway. - const auto temp_vec1 = vld1q_bf16( - reinterpret_cast( - &vec1[registerPairIndex * vec::Vectorized::size()])); - const auto temp_vec2 = vld1q_bf16( - reinterpret_cast( - &vec2[registerPairIndex * vec::Vectorized::size()])); + const auto temp_vec1 = vld1q_bf16(reinterpret_cast( + &vec1[registerPairIndex * vec::Vectorized::size()])); + const auto temp_vec2 = vld1q_bf16(reinterpret_cast( + &vec2[registerPairIndex * vec::Vectorized::size()])); sum[registerPairIndex] = - vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2); + vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2); } -TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE -void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot( +TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void +dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot( const at::BFloat16* vec1, const at::BFloat16* vec2, vec::Vectorized* tail_sum, int idx) { // See NOTE[Intrinsics in bfdot variant] above. - const auto temp_vec1 = vld1q_bf16(reinterpret_cast(&vec1[idx])); - const auto temp_vec2 = vld1q_bf16(reinterpret_cast(&vec2[idx])); + const auto temp_vec1 = + vld1q_bf16(reinterpret_cast(&vec1[idx])); + const auto temp_vec2 = + vld1q_bf16(reinterpret_cast(&vec2[idx])); *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2); } @@ -156,14 +160,17 @@ void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot( namespace { -[[maybe_unused]] std::pair, vec::Vectorized> fmadd( +[[maybe_unused]] std::pair, vec::Vectorized> +fmadd( const vec::Vectorized& a, const vec::Vectorized& b, const vec::Vectorized& acc_low, const vec::Vectorized& acc_high) { const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); - return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high)); + return std::make_pair( + fmadd(a_float_low, b_float_low, acc_low), + fmadd(a_float_high, b_float_high, acc_high)); } [[maybe_unused]] vec::Vectorized fmadd( @@ -172,21 +179,28 @@ namespace { const vec::Vectorized& b) { const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); - return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc)); + return fmadd( + a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc)); } } // namespace template C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( - const T* vec1, - const T* vec2, - vec::VectorizedN& sum, - int registerPairIndex) { + const T* vec1, + const T* vec2, + vec::VectorizedN& sum, + int registerPairIndex) { static_assert(std::is_same_v); - const auto temp_vec1 = vec::Vectorized::loadu(&vec1[registerPairIndex * vec::Vectorized::size()]); - const auto temp_vec2 = vec::Vectorized::loadu(&vec2[registerPairIndex * vec::Vectorized::size()]); - - const auto [result_low, result_high] = fmadd(temp_vec1, temp_vec2, sum[2 * registerPairIndex], sum[2 * registerPairIndex + 1]); + const auto temp_vec1 = vec::Vectorized::loadu( + &vec1[registerPairIndex * vec::Vectorized::size()]); + const auto temp_vec2 = vec::Vectorized::loadu( + &vec2[registerPairIndex * vec::Vectorized::size()]); + + const auto [result_low, result_high] = fmadd( + temp_vec1, + temp_vec2, + sum[2 * registerPairIndex], + sum[2 * registerPairIndex + 1]); sum[2 * registerPairIndex] = result_low; sum[2 * registerPairIndex + 1] = result_high; } @@ -203,19 +217,19 @@ C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot( } template -C10_ALWAYS_INLINE auto -dot_with_fp32_arith_main_loop_no_bfdot( +C10_ALWAYS_INLINE auto dot_with_fp32_arith_main_loop_no_bfdot( const T* vec1, const T* vec2, int64_t len) { vec::VectorizedN sum(0); const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); - for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; - c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE { - dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k); - }); + c10::ForcedUnroll{}( + [vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k); + }); } return reduce(sum); } @@ -224,7 +238,8 @@ dot_with_fp32_arith_main_loop_no_bfdot( template struct ForcedUnrollTargetBFloat16 { template - TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()( + const Func& f) const { ForcedUnrollTargetBFloat16{}(f); f(n - 1); } @@ -233,7 +248,8 @@ struct ForcedUnrollTargetBFloat16 { template <> struct ForcedUnrollTargetBFloat16<1> { template - TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()( + const Func& f) const { f(0); } }; @@ -245,20 +261,22 @@ dot_with_fp32_arith_main_loop_bfdot( int64_t len) { vec::VectorizedN sum(0); const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); - for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; - ForcedUnrollTargetBFloat16{}([vec1_, vec2_, &sum](auto k) - C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE { - dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k); - }); + ForcedUnrollTargetBFloat16{}( + [vec1_, vec2_, &sum](auto k) + C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k); + }); } return reduce(sum); } #endif // COMPILER_SUPPORTS_BF16_TARGET static_assert( - (vec::Vectorized::size() & (vec::Vectorized::size() - 1)) == 0, + (vec::Vectorized::size() & + (vec::Vectorized::size() - 1)) == 0, "Below code expects power-of-2 vector register size!"); // NOTE [GCC code duplication]: The first attempt at landing BFDOT support with @@ -267,31 +285,35 @@ static_assert( // function. We can work around this by duplicating the code into the // bfdot and non-bfdot callsites. The code is in this macro to avoid // actual copy/paste. -#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \ - /* First-tier tail fixup: make sure we handle workloads that can */ \ - /* benefit from vectorization, but don't fit into our fully unrolled */ \ - /* loop above. */ \ - vec::Vectorized tail_sum(0); \ - const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \ +#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \ + /* First-tier tail fixup: make sure we handle workloads that can */ \ + /* benefit from vectorization, but don't fit into our fully unrolled */ \ + /* loop above. */ \ + vec::Vectorized tail_sum(0); \ + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \ const auto len_aligned_vec = len & ~(vec::Vectorized::size() - 1); \ - for (int j = len_aligned; j < len_aligned_vec; j += vec::Vectorized::size()) { \ - dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \ - } \ - reduced_sum += reduce(tail_sum); \ - \ - /* Second-tier tail fixup: handle all workloads. */ \ - for (const auto j : c10::irange(len_aligned_vec, len)) { \ - /* Attempting to use Half here caused multiple test failures; */ \ - /* using float to unbreak. (Suspect we need a scalar FMA.) */ \ - float x1 = vec1[j]; \ - float x2 = vec2[j]; \ - reduced_sum += x1 * x2; \ - } \ + for (int j = len_aligned; j < len_aligned_vec; \ + j += vec::Vectorized::size()) { \ + dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix( \ + vec1, vec2, &tail_sum, j); \ + } \ + reduced_sum += reduce(tail_sum); \ + \ + /* Second-tier tail fixup: handle all workloads. */ \ + for (const auto j : c10::irange(len_aligned_vec, len)) { \ + /* Attempting to use Half here caused multiple test failures; */ \ + /* using float to unbreak. (Suspect we need a scalar FMA.) */ \ + float x1 = vec1[j]; \ + float x2 = vec2[j]; \ + reduced_sum += x1 * x2; \ + } \ return reduced_sum #if COMPILER_SUPPORTS_BF16_TARGET -TARGET_ARM_BF16_ATTRIBUTE float -dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) { +TARGET_ARM_BF16_ATTRIBUTE float dot_with_fp32_arith_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + int64_t len) { auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len); DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot); } @@ -307,7 +329,10 @@ dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) { } // namespace -float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) { +float bf16_dot_with_fp32_arith( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + int64_t len) { #if COMPILER_SUPPORTS_BF16_TARGET if (cpuinfo_has_arm_bf16()) { return dot_with_fp32_arith_bfdot(vec1, vec2, len); From 27af7f6683739da56260f205f9903fe8f791703e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 15 May 2025 14:16:17 -0700 Subject: [PATCH 5/5] also enable for non-aarch64 on "Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel" The two files were similar, but diverged due to recent changes. Since we have sharing of PyTorch headers, we can keep them mostly the same; differences are some of the namespace stuff, lintrunner, and a couple of EXECUTORCH NOTEs. Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.h b/kernels/optimized/blas/BlasKernel.h index fc47b4482d6..1332a881ed5 100644 --- a/kernels/optimized/blas/BlasKernel.h +++ b/kernels/optimized/blas/BlasKernel.h @@ -158,7 +158,6 @@ void gemm_transa_( } } -#ifdef __aarch64__ namespace internal { float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len); } // namespace internal @@ -204,7 +203,6 @@ inline void gemm_transa_( } }); } -#endif // clang-format on