1010
#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif
1415
using torch::executor::BFloat16;
@@ -23,7 +24,7 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
2324  return  vfmaq_f32 (a, b, c);
2425#else 
2526  return  vaddq_f32 (a, vmulq_f32 (b, c));
26- #endif 
27+ #endif   //  __ARM_FEATURE_FMA 
2728}
2829
2930//  The below reduce overload and fp16_dot_with_fp32_arith are adapted
@@ -78,35 +79,39 @@ static ET_INLINE float32x4_t
7879f32_dot_bf16 (float32x4_t  a, bfloat16x8_t  b, bfloat16x8_t  c) {
7980  return  vbfdotq_f32 (a, b, c);
8081}
81- #endif 
82+ #endif   //  __ARM_FEATURE_BF16 
8283
84+ template  <bool  useBfloat16Dot>
8385static  ET_INLINE void  dot_with_fp32_arith_main_inner_loop (
8486    const  BFloat16* vec1,
8587    const  BFloat16* vec2,
8688    float32x4_t  sum[kF32RegistersPerIteration ],
8789    int  registerPairIndex) {
8890#ifdef  __ARM_FEATURE_BF16
89-   const  bfloat16x8_t  temp_vec1 = vld1q_bf16 (reinterpret_cast <const  __bf16*>(
90-       &vec1[registerPairIndex * 2  * kF32ElementsPerRegister ]));
91-   const  bfloat16x8_t  temp_vec2 = vld1q_bf16 (reinterpret_cast <const  __bf16*>(
92-       &vec2[registerPairIndex * 2  * kF32ElementsPerRegister ]));
93-   sum[registerPairIndex] =
94-       f32_dot_bf16 (sum[registerPairIndex], temp_vec1, temp_vec2);
95- #else 
96-   const  uint16x8_t  temp_vec1 = vld1q_u16 (reinterpret_cast <const  uint16_t *>(
97-       &vec1[registerPairIndex * 2  * kF32ElementsPerRegister ]));
98-   const  uint16x8_t  temp_vec2 = vld1q_u16 (reinterpret_cast <const  uint16_t *>(
99-       &vec2[registerPairIndex * 2  * kF32ElementsPerRegister ]));
100- 
101-   sum[2  * registerPairIndex] = f32_fma_bf16 (
102-       sum[2  * registerPairIndex],
103-       vget_low_u16 (temp_vec1),
104-       vget_low_u16 (temp_vec2));
105-   sum[2  * registerPairIndex + 1 ] = f32_fma_bf16 (
106-       sum[2  * registerPairIndex + 1 ],
107-       vget_high_u16 (temp_vec1),
108-       vget_high_u16 (temp_vec2));
109- #endif 
91+   if  (useBfloat16Dot) {
92+     const  bfloat16x8_t  temp_vec1 = vld1q_bf16 (reinterpret_cast <const  __bf16*>(
93+         &vec1[registerPairIndex * 2  * kF32ElementsPerRegister ]));
94+     const  bfloat16x8_t  temp_vec2 = vld1q_bf16 (reinterpret_cast <const  __bf16*>(
95+         &vec2[registerPairIndex * 2  * kF32ElementsPerRegister ]));
96+     sum[registerPairIndex] =
97+         f32_dot_bf16 (sum[registerPairIndex], temp_vec1, temp_vec2);
98+   } else 
99+ #endif  //  __ARM_FEATURE_BF16
100+   {
101+     const  uint16x8_t  temp_vec1 = vld1q_u16 (reinterpret_cast <const  uint16_t *>(
102+         &vec1[registerPairIndex * 2  * kF32ElementsPerRegister ]));
103+     const  uint16x8_t  temp_vec2 = vld1q_u16 (reinterpret_cast <const  uint16_t *>(
104+         &vec2[registerPairIndex * 2  * kF32ElementsPerRegister ]));
105+ 
106+     sum[2  * registerPairIndex] = f32_fma_bf16 (
107+         sum[2  * registerPairIndex],
108+         vget_low_u16 (temp_vec1),
109+         vget_low_u16 (temp_vec2));
110+     sum[2  * registerPairIndex + 1 ] = f32_fma_bf16 (
111+         sum[2  * registerPairIndex + 1 ],
112+         vget_high_u16 (temp_vec1),
113+         vget_high_u16 (temp_vec2));
114+   }
110115}
111116
112117static  ET_INLINE void  dot_with_fp32_arith_vectorized_tail_inner_loop (
@@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
121126  *tailSum = f32_fma_bf16 (*tailSum, temp_vec1, temp_vec2);
122127}
123128
124- template  <typename  T>
129+ template  <typename  T,  bool  useBfloat16Dot >
125130float  dot_with_fp32_arith (const  T* vec1, const  T* vec2, int64_t  len) {
126131  float32x4_t  sum[kF32RegistersPerIteration ] = {vdupq_n_f32 (0 )};
127132  const  auto  len_aligned = len & ~(kF32ElementsPerIteration  - 1 );
@@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
130135    const  auto * vec2_ = vec2 + j;
131136    utils::ForcedUnroll<kF32RegisterPairsPerIteration >{}(
132137        [vec1_, vec2_, &sum](auto  k) ET_INLINE_ATTRIBUTE {
133-           dot_with_fp32_arith_main_inner_loop (vec1_, vec2_, sum, k);
138+           dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
139+               vec1_, vec2_, sum, k);
134140        });
135141  }
136142  auto  reducedSum = reduce (sum);
@@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
157163    const  BFloat16* vec1,
158164    const  BFloat16* vec2,
159165    int64_t  len) {
160-   return  dot_with_fp32_arith (vec1, vec2, len);
166+ #ifdef  __ARM_FEATURE_BF16
167+   if  (cpuinfo_has_arm_bf16 ()) {
168+     return  dot_with_fp32_arith<BFloat16, true >(vec1, vec2, len);
169+   } else 
170+ #endif  //  __ARM_FEATURE_BF16
171+   {
172+     return  dot_with_fp32_arith<BFloat16, false >(vec1, vec2, len);
173+   }
161174}
#endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch
0 commit comments