@@ -55,12 +55,12 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
5555#ifdef GGML_SYCL_F16
5656 // v = v * {d, d};
5757 // v = v + {m, m};
58- v.s0 () = (v.s0 () * d) + m ;
59- v.s1 () = (v.s1 () * d) + m ;
58+ v.s0 () = sycl::fma (v.s0 (), d, m) ;
59+ v.s1 () = sycl::fma (v.s1 (), d, m) ;
6060
6161#else
62- v.x () = (v.x () * d) + m ;
63- v.y () = (v.y () * d) + m ;
62+ v.x () = sycl::fma (v.x (), d, m) ;
63+ v.y () = sycl::fma (v.y (), d, m) ;
6464#endif // GGML_SYCL_F16
6565}
6666
@@ -110,11 +110,11 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
110110#ifdef GGML_SYCL_F16
111111 // v = v * {d, d};
112112 // v = v + {m, m};
113- v.s0 () = (v.s0 () * d) + m ;
114- v.s1 () = (v.s1 () * d) + m ;
113+ v.s0 () = sycl::fma (v.s0 (), d, m) ;
114+ v.s1 () = sycl::fma (v.s1 (), d, m) ;
115115#else
116- v.x () = (v.x () * d) + m ;
117- v.y () = (v.y () * d) + m ;
116+ v.x () = sycl::fma (v.x (), d, m) ;
117+ v.y () = sycl::fma (v.y (), d, m) ;
118118#endif // GGML_SYCL_F16
119119}
120120
0 commit comments