1010
1111#ifdef __aarch64__
1212#include < arm_neon.h>
13+ #include < cpuinfo.h>
1314#endif
1415
1516using torch::executor::BFloat16;
// Fused multiply-add of four f32 lanes: returns a + b * c.
// Uses the hardware FMA instruction when the target advertises it;
// otherwise falls back to a separate multiply followed by an add
// (same result modulo the intermediate rounding FMA avoids).
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif // __ARM_FEATURE_FMA
}
2829
2930// The below reduce overload and fp16_dot_with_fp32_arith are adapted
// Accumulating bf16 dot product: adds the pairwise products of the eight
// bf16 lanes of b and c into the four f32 accumulator lanes of a, via the
// BFDOT instruction (only compiled when __ARM_FEATURE_BF16 is available).
static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}
81- #endif
82+ #endif // __ARM_FEATURE_BF16
8283
// One unrolled step of the main dot-product loop: consumes 16 BFloat16
// elements (one register pair) from each input at offset
// registerPairIndex * 2 * kF32ElementsPerRegister and accumulates their
// products into `sum` in fp32.
//
// useBfloat16Dot selects, at compile time, between the hardware BFDOT path
// and the portable widen-then-FMA path. The branch is on a template
// constant, so `if constexpr` discards the unused path instead of emitting
// a runtime test.
template <bool useBfloat16Dot>
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
#ifdef __ARM_FEATURE_BF16
  if constexpr (useBfloat16Dot) {
    // BFDOT path: a single instruction multiplies 8 bf16 pairs and
    // accumulates into one f32 register, so only sum[registerPairIndex]
    // is touched.
    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
    sum[registerPairIndex] =
        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
  } else
#endif // __ARM_FEATURE_BF16
  {
    // Portable path: load the bf16 bit patterns as u16 and let
    // f32_fma_bf16 widen each half to f32 before the FMA. Two accumulator
    // registers per pair (low and high halves of the u16 vectors).
    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

    sum[2 * registerPairIndex] = f32_fma_bf16(
        sum[2 * registerPairIndex],
        vget_low_u16(temp_vec1),
        vget_low_u16(temp_vec2));
    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
        sum[2 * registerPairIndex + 1],
        vget_high_u16(temp_vec1),
        vget_high_u16(temp_vec2));
  }
}
111116
112117static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop (
@@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
121126 *tailSum = f32_fma_bf16 (*tailSum, temp_vec1, temp_vec2);
122127}
123128
124- template <typename T>
129+ template <typename T, bool useBfloat16Dot >
125130float dot_with_fp32_arith (const T* vec1, const T* vec2, int64_t len) {
126131 float32x4_t sum[kF32RegistersPerIteration ] = {vdupq_n_f32 (0 )};
127132 const auto len_aligned = len & ~(kF32ElementsPerIteration - 1 );
@@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
130135 const auto * vec2_ = vec2 + j;
131136 utils::ForcedUnroll<kF32RegisterPairsPerIteration >{}(
132137 [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
133- dot_with_fp32_arith_main_inner_loop (vec1_, vec2_, sum, k);
138+ dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
139+ vec1_, vec2_, sum, k);
134140 });
135141 }
136142 auto reducedSum = reduce (sum);
@@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
157163 const BFloat16* vec1,
158164 const BFloat16* vec2,
159165 int64_t len) {
160- return dot_with_fp32_arith (vec1, vec2, len);
166+ #ifdef __ARM_FEATURE_BF16
167+ if (cpuinfo_has_arm_bf16 ()) {
168+ return dot_with_fp32_arith<BFloat16, true >(vec1, vec2, len);
169+ } else
170+ #endif // __ARM_FEATURE_BF16
171+ {
172+ return dot_with_fp32_arith<BFloat16, false >(vec1, vec2, len);
173+ }
161174}
162- #endif
175+ #endif // __aarch64__
163176} // namespace internal
164177} // namespace cpublas
165178} // namespace executorch
0 commit comments