Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions kernels/optimized/blas/BlasKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,26 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

#if defined(__ARM_FEATURE_BF16)
// Thin wrapper over the vbfdotq_f32 intrinsic: `a` is forwarded as the
// accumulator operand and `b`, `c` as the two BF16 vector operands. Only
// compiled when the target defines __ARM_FEATURE_BF16.
static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  const float32x4_t accumulated = vbfdotq_f32(a, b, c);
  return accumulated;
}
#endif // defined(__ARM_FEATURE_BF16)

// Performs one inner-loop step of a BF16 dot product, accumulating into the
// float32 registers in `sum`.
//
// vec1, vec2:        BF16 input vectors; each call reads a slice starting at
//                    element registerPairIndex * 2 * kF32ElementsPerRegister.
// sum:               array of kF32RegistersPerIteration float32x4_t
//                    accumulators, updated in place.
// registerPairIndex: selects both the input slice and the accumulator(s)
//                    written by this call.
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
// TODO: detect intrinsic availability, use them if they're available.
// NOTE(review): the __ARM_FEATURE_BF16 branch below appears to address this
// TODO at compile time -- confirm whether the TODO can be removed.
//
// Load a pair of f32 registers' worth of 16-bit elements at a time.
#ifdef __ARM_FEATURE_BF16
// BF16 intrinsics path: load each slice as a bfloat16x8_t and accumulate the
// result into the single register sum[registerPairIndex] via f32_dot_bf16.
// Note this indexes `sum` by registerPairIndex, whereas the fallback path
// indexes it by 2 * registerPairIndex (+ 1).
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
#else
// Fallback path: load the same slices as raw 16-bit lanes (uint16x8_t). The
// visible tail of this branch passes the high halves of both loads together
// with sum[2 * registerPairIndex + 1] to a helper.
// NOTE(review): part of this branch is not visible in this view -- confirm
// how the low halves are handled.
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
Expand All @@ -93,6 +106,7 @@ static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
#endif
}

static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
Expand Down
Loading