@@ -74,43 +74,60 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
7474 return f32_fma (a, to_bfloat16 (b), to_bfloat16 (c));
7575}
7676
// Per-function target override: emit ARMv8.2-a+BF16 instructions for the
// annotated function even though the translation unit as a whole is not
// compiled with +bf16. Safe because callers runtime-dispatch on
// cpuinfo_has_arm_bf16() before entering these code paths.
// NOTE(review): the target string is parsed verbatim by GCC/clang — it must
// not contain stray spaces (the pasted diff showed " arch=..." with a
// leading space, which would not match a valid target arch string).
#define ET_TARGET_ARM_BF16_ATTRIBUTE \
  __attribute__((target("arch=armv8.2-a+bf16")))

// Returns a + dot(b, c): each float32 lane of the result accumulates the
// products of the corresponding pairs of bfloat16 lanes (BFDOT instruction).
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}
// One unrolled step of the main dot-product loop using the BFDOT
// instruction: loads 8 bfloat16 elements from each input at the offset
// selected by registerPairIndex and accumulates their dot product into a
// single float32x4 accumulator. Requires a BF16-capable CPU (the caller
// dispatches on cpuinfo_has_arm_bf16()).
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  // reinterpret_cast is safe: BFloat16 is a bit-identical wrapper around
  // the 16-bit brain-float representation that __bf16 denotes.
  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  // BFDOT folds 8 bf16 products into 4 f32 lanes, so only one accumulator
  // register per pair is needed (contrast with the no-bfdot path, which
  // uses two).
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}
// One unrolled step of the main dot-product loop for CPUs WITHOUT the BF16
// extension: loads 8 bfloat16 elements from each input as raw uint16 lanes
// and widens them to float32 (via f32_fma_bf16) in two halves, so it needs
// two float32x4 accumulators per register pair — sum[2k] for the low half
// and sum[2k+1] for the high half. The split sums are combined later by
// reduce().
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  // Load the bf16 payloads as plain 16-bit integers; widening to f32 is a
  // bit shift, handled inside f32_fma_bf16.
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
}
// Compile-time dispatcher between the BFDOT and the portable (no-bfdot)
// inner-loop bodies. The branch is resolved at instantiation time via
// `if constexpr`, so each instantiation contains exactly one code path and
// the function-multiversioning target attribute only matters for the
// useBfdot == true instantiation.
template <bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  if constexpr (useBfdot) {
    dot_with_fp32_arith_main_inner_loop_bfdot(
        vec1, vec2, sum, registerPairIndex);
  } else {
    dot_with_fp32_arith_main_inner_loop_no_bfdot(
        vec1, vec2, sum, registerPairIndex);
  }
}
@@ -126,18 +143,40 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
126143 *tailSum = f32_fma_bf16 (*tailSum, temp_vec1, temp_vec2);
127144}
128145
129- template <typename T, bool useBfloat16Dot>
130- float dot_with_fp32_arith (const T* vec1, const T* vec2, int64_t len) {
146+ namespace {
147+ template <int n>
148+ struct ForcedUnrollTargetBFloat16 {
149+ template <typename Func>
150+ ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator ()(const Func& f) const {
151+ ForcedUnrollTargetBFloat16<n - 1 >{}(f);
152+ f (n - 1 );
153+ }
154+ };
155+
156+ template <>
157+ struct ForcedUnrollTargetBFloat16 <1 > {
158+ template <typename Func>
159+ ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator ()(const Func& f) const {
160+ f (0 );
161+ }
162+ };
163+
164+ } // namespace
165+
166+ template <typename T, bool useBFloat16Dot>
167+ ET_TARGET_ARM_BF16_ATTRIBUTE float
168+ dot_with_fp32_arith (const T* vec1, const T* vec2, int64_t len) {
131169 float32x4_t sum[kF32RegistersPerIteration ] = {vdupq_n_f32 (0 )};
132170 const auto len_aligned = len & ~(kF32ElementsPerIteration - 1 );
133171 for (int j = 0 ; j < len_aligned; j += kF32ElementsPerIteration ) {
134172 const auto * vec1_ = vec1 + j;
135173 const auto * vec2_ = vec2 + j;
136- utils::ForcedUnroll<kF32RegisterPairsPerIteration >{}(
137- [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
138- dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
139- vec1_, vec2_, sum, k);
140- });
174+ ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration >{}(
175+ [vec1_, vec2_, &sum](auto k)
176+ ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
177+ dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
178+ vec1_, vec2_, sum, k);
179+ });
141180 }
142181 auto reducedSum = reduce (sum);
143182
@@ -163,12 +202,9 @@ float bf16_dot_with_fp32_arith(
163202 const BFloat16* vec1,
164203 const BFloat16* vec2,
165204 int64_t len) {
166- #ifdef __ARM_FEATURE_BF16
167205 if (cpuinfo_has_arm_bf16 ()) {
168206 return dot_with_fp32_arith<BFloat16, true >(vec1, vec2, len);
169- } else
170- #endif // __ARM_FEATURE_BF16
171- {
207+ } else {
172208 return dot_with_fp32_arith<BFloat16, false >(vec1, vec2, len);
173209 }
174210}
0 commit comments