61 changes: 38 additions & 23 deletions kernels/optimized/blas/BlasKernel.cpp
@@ -10,6 +10,7 @@

#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

using torch::executor::BFloat16;
@@ -80,32 +81,37 @@ f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
}
#endif

template <bool useBfloat16Dot>
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
#ifdef __ARM_FEATURE_BF16
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
#else
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
if (useBfloat16Dot) {
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
} else {
#endif
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
#ifdef __ARM_FEATURE_BF16
}
#endif
}

@@ -121,7 +127,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
*tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

template <typename T>
template <typename T, bool useBfloat16Dot>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -130,7 +136,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
const auto* vec2_ = vec2 + j;
utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
[vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
vec1_, vec2_, sum, k);
});
}
auto reducedSum = reduce(sum);
@@ -157,7 +164,15 @@ float bf16_dot_with_fp32_arith(
const BFloat16* vec1,
const BFloat16* vec2,
int64_t len) {
return dot_with_fp32_arith(vec1, vec2, len);
#ifdef __ARM_FEATURE_BF16
if (cpuinfo_has_arm_bf16()) {
return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
} else {
#endif
return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
#ifdef __ARM_FEATURE_BF16
}
#endif
}
#endif
} // namespace internal
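A minimal sketch of the dispatch pattern this change introduces (hypothetical standalone caller, assuming cpuinfo is linked and available on the target): the #ifdef keeps the BF16 intrinsics out of binaries built without BF16 support, while cpuinfo_has_arm_bf16() confirms at runtime that the CPU actually implements them before the fast path is taken.

#include <cpuinfo.h>

float bf16_dot_dispatch_example(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
#ifdef __ARM_FEATURE_BF16
  // Runtime check: only use the BF16 dot-product kernel on CPUs that report BF16 support.
  if (cpuinfo_initialize() && cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith<BFloat16, /*useBfloat16Dot=*/true>(vec1, vec2, len);
  }
#endif
  // Portable fallback: widen bf16 to f32 and accumulate with ordinary FMAs.
  return dot_with_fp32_arith<BFloat16, /*useBfloat16Dot=*/false>(vec1, vec2, len);
}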
11 changes: 11 additions & 0 deletions kernels/optimized/lib_defs.bzl
@@ -129,6 +129,14 @@ def define_libs():
] if not runtime.is_oss else [],
"DEFAULT": [],
}),
fbandroid_platform_compiler_flags = [
(
"^android-arm64.*$",
[
"-march=armv8+bf16",
],
),
],
fbandroid_platform_preprocessor_flags = [
(
"^android-arm64.*$",
@@ -145,6 +153,9 @@
],
),
],
fbobjc_compiler_flags = [
"-march=armv8+bf16",
],
fbobjc_exported_preprocessor_flags = [
"-DET_BUILD_WITH_BLAS",
"-DET_BUILD_FOR_APPLE",
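As a quick sanity check on the new -march=armv8+bf16 flags (a sketch, not part of the change): with that flag the compiler defines __ARM_FEATURE_BF16 and exposes the bfloat16 NEON vector types and the hardware dot-product intrinsic that BlasKernel.cpp relies on.

#ifdef __ARM_FEATURE_BF16
#include <arm_neon.h>

// Accumulates the BF16 dot product of a and b into acc using the hardware
// BFDOT instruction; only compiles when -march=armv8+bf16 (or newer) is set.
float32x4_t bf16_dot_probe(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
  return vbfdotq_f32(acc, a, b);
}
#endif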
8 changes: 4 additions & 4 deletions kernels/test/op_linear_test.cpp
@@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
}
}

// matmul gives 4 * 2 * 3 = 24
Tensor x = tf.full({3, 4}, 2);
Tensor y = tf.full({5, 4}, 3);
// matmul gives 32 * 2 * 3 = 192
Tensor x = tf.full({3, 32}, 2);
Tensor y = tf.full({5, 32}, 3);

// Output shape should be (3, 5)
Tensor out = tf.zeros({3, 5});

op_linear_out(x, y, out);

Tensor expected = tf.full({3, 5}, 24);
Tensor expected = tf.full({3, 5}, 192);

EXPECT_TENSOR_EQ(out, expected);
}
3 changes: 2 additions & 1 deletion shim/xplat/executorch/build/env_interface.bzl
@@ -118,7 +118,8 @@ def _remove_platform_specific_args(kwargs):
"""
keys = []
for key in kwargs:
if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"):
if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
keys.append(key)
for key in keys:
kwargs.pop(key)