diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt
index ae6d8e6fcd3..bbd6999ad33 100644
--- a/kernels/optimized/CMakeLists.txt
+++ b/kernels/optimized/CMakeLists.txt
@@ -42,6 +42,7 @@ endif()
 # Build cpublas.
 list(TRANSFORM _optimized_cpublas__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(cpublas STATIC ${_optimized_cpublas__srcs})
+target_include_directories(cpublas PRIVATE ${TORCH_INCLUDE_DIRS})
 target_link_libraries(
   cpublas PUBLIC executorch_core eigen_blas extension_threadpool
 )
diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp
index a3e2172504d..4d833db65f6 100644
--- a/kernels/optimized/blas/BlasKernel.cpp
+++ b/kernels/optimized/blas/BlasKernel.cpp
@@ -6,148 +6,240 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+// NOTE: This file is mostly the same as
+// ReducedPrecisionFloatGemvFastPathKernel.cpp in PyTorch. Actually
+// sharing the two versions is a TODO.
 #include <executorch/kernels/optimized/blas/BlasKernel.h>
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
 
 #ifdef __aarch64__
 #include <arm_neon.h>
 #include <cpuinfo.h>
 #endif
 
+namespace vec = at::vec;
+using executorch::extension::parallel_for;
 using torch::executor::BFloat16;
+using torch::executor::Half;
 
-namespace executorch {
-namespace cpublas {
-namespace internal {
-#ifdef __aarch64__
-static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
-#ifdef __ARM_FEATURE_FMA
-  return vfmaq_f32(a, b, c);
+namespace executorch::cpublas::internal {
+constexpr auto kF32RegisterPairsPerIteration = 4;
+constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
+constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
+constexpr auto kF32ElementsPerIteration =
+    kF32RegistersPerIteration * kF32ElementsPerRegister;
+
+namespace {
+template <typename T>
+constexpr int IntegerLog2(T n, int p = 0) {
+  return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
+}
+
+/*
+ * NOTE [ GGML Copyright Notice ]
+ * The below reduce overload and fp16_dot_with_fp16_arith function is
+ * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility
+ * functions, so here is the required copyright notice:
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+float reduce(vec::Vectorized<float> x) {
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
+  return vaddvq_f32(x);
 #else
-  return vaddq_f32(a, vmulq_f32(b, c));
-#endif // __ARM_FEATURE_FMA
+  return vec::vec_reduce_all(std::plus<vec::Vectorized<float>>(), x);
+#endif
 }
 
 // The below reduce overload and fp16_dot_with_fp32_arith are adapted
 // from llama.cpp's ggml_vec_dot_f32 and surrounding utility
 // functions. See NOTE [ GGML Copyright Notice ] above for the
 // required notice.
-
-// We need the shift for reduce(), hence the extra constants.
-static constexpr auto kF32ElementsPerIterationShift = 5;
-static constexpr auto kF32ElementsPerIteration = 1
-    << kF32ElementsPerIterationShift;
-static_assert(kF32ElementsPerIteration == 32);
-
-static constexpr auto kF32ElementsPerRegisterShift = 2;
-static constexpr auto kF32ElementsPerRegister = 1
-    << kF32ElementsPerRegisterShift;
-static_assert(kF32ElementsPerRegister == 4);
-
-static constexpr auto kF32RegisterPairsPerIteration = 4;
-static constexpr auto kF32RegistersPerIteration =
-    kF32RegisterPairsPerIteration * 2;
-static constexpr auto kF32RegistersPerIterationShift = 3;
-static_assert(
-    kF32RegistersPerIteration ==
-    kF32ElementsPerIteration / kF32ElementsPerRegister);
-static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
-
-static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
+float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
   int offset = kF32RegistersPerIteration;
-  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
-      [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
+  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}(
+      [&offset, &x](auto idx) {
        offset /= 2;
-        for (int i = 0; i < offset; ++i) {
-          x[i] = vaddq_f32(x[i], x[offset + i]);
+        for (const auto i : c10::irange(offset)) {
+          x[i] = x[i] + x[offset + i];
        }
      });
-  return vaddvq_f32(x[0]);
+  return reduce(x[0]);
 }
 
-static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
-  int32x4_t shift = vdupq_n_s32(16);
-  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
-}
+// EXECUTORCH NOTE: removed __ARM_FEATURE_BF16_VECTOR_ARITHMETIC gate
+// added in https://github.com/pytorch/pytorch/pull/152766, which I
+// complained on.
 
-static ET_INLINE float32x4_t
-f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
-  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
-}
+// We would have to write a separate SVE-specific path to use SVE
+// BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
+// working.
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    defined(__clang__) && __clang_major__ > 15
+// https://godbolt.org/z/z8P4Yncra
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+// https://gcc.gnu.org/gcc-10/changes.html
+// https://godbolt.org/z/cdGG7vn8o
+#define COMPILER_SUPPORTS_BF16_TARGET 1
+#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+      // defined(__clang__) && __clang_major__ > 15
+#define COMPILER_SUPPORTS_BF16_TARGET 0
+#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+       // defined(__clang__) && __clang_major__ > 15
 
-#define ET_TARGET_ARM_BF16_ATTRIBUTE \
-  __attribute__((target("arch=armv8.2-a+bf16")))
-ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
-f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
-  return vbfdotq_f32(a, b, c);
-}
+#if COMPILER_SUPPORTS_BF16_TARGET
+#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
 
-ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
 dot_with_fp32_arith_main_inner_loop_bfdot(
     const BFloat16* vec1,
     const BFloat16* vec2,
-    float32x4_t sum[kF32RegistersPerIteration],
+    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
     int registerPairIndex) {
-  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  // NOTE[Intrinsics in bfdot variant]: We can't use
+  // vec::Vectorized<BFloat16>::loadu here because linux-aarch64 GCC
+  // inexplicably can't convert Vectorized<BFloat16> to
+  // bfloat16x8_t. I suspect a bug or incomplete
+  // __attribute__((target)) implementation. Intrinsics should be fine
+  // because we're using vbfdotq_f32 below anyway.
+  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
   sum[registerPairIndex] =
-      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+      vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
 }
 
-static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
-    const BFloat16* vec1,
-    const BFloat16* vec2,
-    float32x4_t sum[kF32RegistersPerIteration],
-    int registerPairIndex) {
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
+dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    vec::Vectorized<float>* tail_sum,
+    int idx) {
+  // See NOTE[Intrinsics in bfdot variant] above.
+  const auto temp_vec1 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
+  const auto temp_vec2 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
+  *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
+}
 
-  sum[2 * registerPairIndex] = f32_fma_bf16(
-      sum[2 * registerPairIndex],
-      vget_low_u16(temp_vec1),
-      vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-      sum[2 * registerPairIndex + 1],
-      vget_high_u16(temp_vec1),
-      vget_high_u16(temp_vec2));
+#else
+#define TARGET_ARM_BF16_ATTRIBUTE
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+namespace {
+
+[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>>
+fmadd(
+    const vec::Vectorized<BFloat16>& a,
+    const vec::Vectorized<BFloat16>& b,
+    const vec::Vectorized<float>& acc_low,
+    const vec::Vectorized<float>& acc_high) {
+  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
+  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
+  return std::make_pair(
+      fmadd(a_float_low, b_float_low, acc_low),
+      fmadd(a_float_high, b_float_high, acc_high));
 }
 
-template <bool useBfdot>
-ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
-dot_with_fp32_arith_main_inner_loop(
-    const BFloat16* vec1,
-    const BFloat16* vec2,
-    float32x4_t sum[kF32RegistersPerIteration],
+[[maybe_unused]] vec::Vectorized<float> fmadd(
+    const vec::Vectorized<float>& acc,
+    const vec::Vectorized<BFloat16>& a,
+    const vec::Vectorized<BFloat16>& b) {
+  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
+  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
+  return fmadd(
+      a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
+}
+} // namespace
+
+template <typename T>
+C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
+    const T* vec1,
+    const T* vec2,
+    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
     int registerPairIndex) {
-  if constexpr (useBfdot) {
-    dot_with_fp32_arith_main_inner_loop_bfdot(
-        vec1, vec2, sum, registerPairIndex);
-  } else {
-    dot_with_fp32_arith_main_inner_loop_no_bfdot(
-        vec1, vec2, sum, registerPairIndex);
-  }
+  static_assert(std::is_same_v<T, BFloat16>);
+  const auto temp_vec1 = vec::Vectorized<T>::loadu(
+      &vec1[registerPairIndex * vec::Vectorized<T>::size()]);
+  const auto temp_vec2 = vec::Vectorized<T>::loadu(
+      &vec2[registerPairIndex * vec::Vectorized<T>::size()]);
+
+  const auto [result_low, result_high] = fmadd(
+      temp_vec1,
+      temp_vec2,
+      sum[2 * registerPairIndex],
+      sum[2 * registerPairIndex + 1]);
+  sum[2 * registerPairIndex] = result_low;
+  sum[2 * registerPairIndex + 1] = result_high;
 }
 
-static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
-    const BFloat16* vec1,
-    const BFloat16* vec2,
-    float32x4_t* tailSum,
+template <typename T>
+C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
+    const T* vec1,
+    const T* vec2,
+    vec::Vectorized<float>* tail_sum,
     int idx) {
-  const auto temp_vec1 =
-      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
-  const auto temp_vec2 =
-      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
-  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
+  const auto temp_vec1 = vec::Vectorized<T>::loadu(&vec1[idx]);
+  const auto temp_vec2 = vec::Vectorized<T>::loadu(&vec2[idx]);
+  *tail_sum = fmadd(*tail_sum, temp_vec1, temp_vec2);
 }
 
-namespace {
+template <typename T>
+C10_ALWAYS_INLINE auto dot_with_fp32_arith_main_loop_no_bfdot(
+    const T* vec1,
+    const T* vec2,
+    int64_t len) {
+  vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
+    const auto* vec1_ = vec1 + j;
+    const auto* vec2_ = vec2 + j;
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+          dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+        });
+  }
+  return reduce(sum);
+}
+
+#if COMPILER_SUPPORTS_BF16_TARGET
 template <int n>
 struct ForcedUnrollTargetBFloat16 {
   template <typename Func>
-  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     ForcedUnrollTargetBFloat16<n - 1>{}(f);
     f(n - 1);
   }
@@ -156,59 +248,99 @@ struct ForcedUnrollTargetBFloat16 {
 template <>
 struct ForcedUnrollTargetBFloat16<1> {
   template <typename Func>
-  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     f(0);
   }
 };
 
-} // namespace
-
-template <typename T, bool useBfdot>
-ET_TARGET_ARM_BF16_ATTRIBUTE float
-dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
-  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
+C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto
+dot_with_fp32_arith_main_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
+  vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
   for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
     ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
         [vec1_, vec2_, &sum](auto k)
-            ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
-              dot_with_fp32_arith_main_inner_loop<useBfdot>(
-                  vec1_, vec2_, sum, k);
+            C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+              dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
        });
  }
-  auto reducedSum = reduce(sum);
-
-  // First-tier tail fixup: make sure we handle workloads that can
-  // benefit from vectorization, but don't fit into our fully unrolled
-  // loop above.
-  float32x4_t tailSum = vdupq_n_f32(0);
-  const auto len_aligned_4 = len & ~3;
-  for (int j = len_aligned; j < len_aligned_4; j += 4) {
-    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
-  }
-  auto reducedTail = vpaddq_f32(tailSum, tailSum);
-  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
-
-  // Second-tier tail fixup: handle all workloads.
-  for (int j = len_aligned_4; j < len; ++j) {
-    reducedSum += vec1[j] * vec2[j];
-  }
-  return reducedSum;
+  return reduce(sum);
 }
+#endif // COMPILER_SUPPORTS_BF16_TARGET
 
-float bf16_dot_with_fp32_arith(
+static_assert(
+    (vec::Vectorized<BFloat16>::size() &
+     (vec::Vectorized<BFloat16>::size() - 1)) == 0,
+    "Below code expects power-of-2 vector register size!");
+
+// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
+// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not
+// allow inlining a non-bf16-specific function into a bf16-specific
+// function. We can work around this by duplicating the code into the
+// bfdot and non-bfdot callsites. The code is in this macro to avoid
+// actual copy/paste.
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
+  /* First-tier tail fixup: make sure we handle workloads that can */ \
+  /* benefit from vectorization, but don't fit into our fully unrolled */ \
+  /* loop above. */ \
+  vec::Vectorized<float> tail_sum(0); \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
+  const auto len_aligned_vec = len & ~(vec::Vectorized<BFloat16>::size() - 1); \
+  for (int j = len_aligned; j < len_aligned_vec; \
+       j += vec::Vectorized<BFloat16>::size()) { \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix( \
+        vec1, vec2, &tail_sum, j); \
+  } \
+  reduced_sum += reduce(tail_sum); \
+ \
+  /* Second-tier tail fixup: handle all workloads. */ \
+  for (const auto j : c10::irange(len_aligned_vec, len)) { \
+    /* Attempting to use Half here caused multiple test failures; */ \
+    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
+    float x1 = vec1[j]; \
+    float x2 = vec2[j]; \
+    reduced_sum += x1 * x2; \
+  } \
+  return reduced_sum
+
+#if COMPILER_SUPPORTS_BF16_TARGET
+TARGET_ARM_BF16_ATTRIBUTE float dot_with_fp32_arith_bfdot(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
+}
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+
+template <typename T>
+C10_ALWAYS_INLINE float
+dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
+  auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len);
+  DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot);
+}
+#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY
+
+} // namespace
+
+float bf16_dot_with_fp32_arith(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    int64_t len) {
+#if COMPILER_SUPPORTS_BF16_TARGET
   if (cpuinfo_has_arm_bf16()) {
-    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
-  } else {
-    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
+    return dot_with_fp32_arith_bfdot(vec1, vec2, len);
+  } else
+#endif // COMPILER_SUPPORTS_BF16_TARGET
+  {
+    return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
  }
 }
-#endif // __aarch64__
-} // namespace internal
-} // namespace cpublas
-} // namespace executorch
+
+} // namespace executorch::cpublas::internal
diff --git a/kernels/optimized/blas/BlasKernel.h b/kernels/optimized/blas/BlasKernel.h
index fc47b4482d6..1332a881ed5 100644
--- a/kernels/optimized/blas/BlasKernel.h
+++ b/kernels/optimized/blas/BlasKernel.h
@@ -158,7 +158,6 @@ void gemm_transa_(
   }
 }
 
-#ifdef __aarch64__
 namespace internal {
 float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
 } // namespace internal
@@ -204,7 +203,6 @@ inline void gemm_transa_(
     }
   });
 }
-#endif
 
 // clang-format on
diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl
index 85886365a01..d960db852bb 100644
--- a/kernels/optimized/lib_defs.bzl
+++ b/kernels/optimized/lib_defs.bzl
@@ -2,10 +2,6 @@ load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFOR
 load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
 load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load(
-    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
-    "get_compiler_optimization_flags",
-)
 
 # Because vec exists as a collection of header files, compile and preprocessor
 # flags applied to the vec target do not have any effect, since no compilation
@@ -200,7 +196,12 @@ def define_libs(is_fbcode=False):
         exported_headers = native.glob([
            "blas/**/*.h",
        ]),
-        compiler_flags = get_compiler_optimization_flags(),
+        compiler_flags = ["-Wno-pass-failed"] + select({
+            "ovr_config//runtime:fbcode": [],
+            # TODO: replace with get_compiler_optimization_flags from op_registration_util.bzl when that
+            # is re-enabled.
+            "DEFAULT": ["-Os"],
+        }),
         header_namespace = "executorch/kernels/optimized",
         visibility = [
            "//executorch/...",
@@ -235,6 +236,7 @@ def define_libs(is_fbcode=False):
            "//executorch/extension/threadpool:threadpool",
            "//executorch/kernels/optimized:libutils",
            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
        ],
        **get_apple_framework_deps_kwargs(is_fbcode),
    )
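
Reviewer note: the DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY macro above splits the dot product into three tiers -- the fully unrolled main loop over kF32ElementsPerIteration-sized chunks, a vectorized first-tier tail, and a scalar second-tier tail. The sketch below is a minimal plain-C++ model of only that index arithmetic, so the tier boundaries are easy to check; the step sizes (32 and 8) assume NEON-width registers, and the function name is illustrative rather than part of the change.

#include <cstdint>

// Sketch of the three-tier split, accumulating in plain float.
// Assumes kF32ElementsPerIteration == 32 and Vectorized<BFloat16>::size() == 8 (NEON).
float dot_three_tier_sketch(const float* vec1, const float* vec2, int64_t len) {
  constexpr int64_t kMainStep = 32; // elements handled per unrolled main-loop iteration
  constexpr int64_t kTailStep = 8;  // elements handled per first-tier tail iteration
  const int64_t len_main = len & ~(kMainStep - 1); // main-loop boundary
  const int64_t len_tail = len & ~(kTailStep - 1); // first-tier tail boundary
  float sum = 0.0f;
  for (int64_t j = 0; j < len_main; ++j) {
    sum += vec1[j] * vec2[j]; // covered by the unrolled main loop
  }
  for (int64_t j = len_main; j < len_tail; ++j) {
    sum += vec1[j] * vec2[j]; // covered by the vectorized first-tier tail
  }
  for (int64_t j = len_tail; j < len; ++j) {
    sum += vec1[j] * vec2[j]; // covered by the scalar second-tier tail
  }
  return sum;
}

Every element index lands in exactly one of the three loops, which is why the kernel's result matches a straightforward fp32-accumulated reference dot product regardless of where len falls relative to the 32- and 8-element boundaries.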