
Commit dab8130

pytorchbot and malfet authored
[vec128] Fix fmsub NEON definition (pytorch#153093)
[vec128] Fix fmsub NEON definition (pytorch#152075)

As reported in pytorch#149292, according to the manual, `vfmsq_f32` implements `c - a * b` rather than `a * b - c`, so its call must be prefixed with `vnegq_f32`.

Also adjust the tests to use OpMath for the reference FMA computation, to avoid accuracy error accumulation from non-fused multiply-and-add over lower-precision dtypes.

Note that `Vectorized::fmsub` is not currently instantiated anywhere, so it could safely remain broken.

TODO:
- Enable C++ testing on MacOS and/or aarch64 platforms (right now Mac tests are built without C++ tests)

Fixes pytorch#149292

Pull Request resolved: pytorch#152075
Approved by: https://github.com/swolchok
ghstack dependencies: pytorch#151955
(cherry picked from commit 2ea8653)
Co-authored-by: Nikita Shulga <[email protected]>
1 parent 20d62a8 commit dab8130
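
For readers unfamiliar with the intrinsic, here is a minimal standalone sketch (not part of this commit) showing the semantics the commit message describes: `vfmsq_f32(c, a, b)` computes `c - a * b`, so negating the result with `vnegq_f32` recovers the intended `a * b - c`. It only builds on an aarch64 target.

```cpp
// Sketch only (not from the PR): demonstrates vfmsq_f32 semantics on aarch64.
// Build on an ARM64 machine, e.g.: g++ -O2 fmsub_demo.cpp -o fmsub_demo
#include <arm_neon.h>
#include <cstdio>

int main() {
  float32x4_t a = vdupq_n_f32(2.0f);
  float32x4_t b = vdupq_n_f32(3.0f);
  float32x4_t c = vdupq_n_f32(10.0f);

  // vfmsq_f32(c, a, b) == c - a * b  ->  10 - 6 = 4
  float32x4_t fms = vfmsq_f32(c, a, b);
  // Negating recovers fmsub(a, b, c) == a * b - c  ->  6 - 10 = -4
  float32x4_t fmsub = vnegq_f32(fms);

  printf("vfmsq_f32: %f, negated: %f\n",
         vgetq_lane_f32(fms, 0), vgetq_lane_f32(fmsub, 0));
  return 0;
}
```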

File tree

aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h
aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h
aten/src/ATen/test/vec_test_all_types.h

3 files changed: +18 −6 lines

aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h

Lines changed: 1 addition & 1 deletion
@@ -540,7 +540,7 @@ Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<floa
 
 template <>
 Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
-  return Vectorized<float>(vfmsq_f32(c, a, b));
+  return Vectorized<float>(vnegq_f32(vfmsq_f32(c, a, b)));
 }
 
 inline Vectorized<float> Vectorized<float>::erf() const{

aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h

Lines changed: 1 addition & 1 deletion
@@ -582,7 +582,7 @@ Vectorized<c10::Half> inline fmsub(
     const Vectorized<c10::Half>& b,
     const Vectorized<c10::Half>& c) {
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-  return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
+  return Vectorized<c10::Half>(vnegq_f16(vfmsq_f16(c, a, b)));
 #else
   return a * b - c;
 #endif

aten/src/ATen/test/vec_test_all_types.h

Lines changed: 16 additions & 4 deletions
@@ -64,6 +64,16 @@ CACHE_ALIGN #define
 #undef CHECK_WITH_FMA
 #endif
 
+template <typename scalar_t>
+struct OpMathType {
+  using type = scalar_t;
+};
+template <>
+struct OpMathType<c10::Half> {
+  using type = float;
+};
+
+
 template<typename T>
 using Complex = typename c10::complex<T>;
 
@@ -1279,15 +1289,17 @@ std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_division(Compl
 template <typename T>
 std::enable_if_t<!is_complex<T>::value, T> local_fmadd(T a, T b, T c) {
   PreventFma noFma;
-  T ab = a * b;
-  return noFma.add(ab, c);
+  using op_math_t = typename OpMathType<T>::type;
+  auto ab = static_cast<op_math_t>(a) * static_cast<op_math_t>(b);
+  return static_cast<T>(noFma.add(ab, op_math_t(c)));
 }
 
 template <typename T>
 std::enable_if_t<!is_complex<T>::value, T> local_fmsub(T a, T b, T c) {
   PreventFma noFma;
-  T ab = a * b;
-  return noFma.sub(ab, c);
+  using op_math_t = typename OpMathType<T>::type;
+  auto ab = static_cast<op_math_t>(a) * static_cast<op_math_t>(b);
+  return static_cast<T>(noFma.sub(ab, op_math_t(c)));
 }
 
 template <typename T>
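
The test change above follows the op-math pattern: widen reduced-precision operands to their op-math type before the multiply and the add/sub, then narrow once at the end, so the scalar reference does not accumulate extra rounding error. Below is a minimal self-contained sketch of that pattern; the toy `Half` type and the omission of the test's `PreventFma` helper are assumptions made here for illustration, not code from the PR.

```cpp
// Sketch only: mirrors the OpMathType widening used by local_fmadd in the test,
// with a placeholder 'Half' instead of c10::Half and no PreventFma guard.
#include <cstdio>

struct Half {                       // stand-in for a 16-bit float type
  float v;                          // stored value (toy representation)
  explicit Half(float x) : v(x) {}
  operator float() const { return v; }
};

template <typename scalar_t>
struct OpMathType { using type = scalar_t; };

template <>
struct OpMathType<Half> { using type = float; };  // do Half math in float

template <typename T>
T local_fmadd(T a, T b, T c) {
  using op_math_t = typename OpMathType<T>::type;
  // Multiply and add in the wider type, then narrow once at the end.
  auto ab = static_cast<op_math_t>(a) * static_cast<op_math_t>(b);
  return static_cast<T>(ab + static_cast<op_math_t>(c));
}

int main() {
  Half a(1.5f), b(2.0f), c(0.25f);
  printf("fmadd = %f\n", static_cast<float>(local_fmadd(a, b, c)));  // 3.25
  return 0;
}
```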
