@@ -184,19 +184,62 @@ void kernel_mul(int *result) {
184184 *result = r;
185185}
186186
187+ template <typename T>
188+ std::array<T,2 > complex_mul (std::array<T, 2 > a, std::array<T, 2 > b){
189+ std::array<T, 2 > result;
190+ result[0 ] = (a[0 ] * b[0 ]) - (a[1 ] * b[1 ]);
191+ result[1 ] = (a[0 ] * b[1 ]) + (a[1 ] * b[0 ]);
192+ return result;
193+ }
194+
195+ template <typename T>
196+ std::array<T,2 > complex_add (std::array<T, 2 > a, std::array<T, 2 > b){
197+ return {a[0 ] + b[0 ], a[1 ] + b[1 ]};
198+ }
199+
200+ template <typename T> void mul_add_groundtruth () {
201+
202+ using complex_t = std::complex <double >;
203+ using arr_t = std::array<T, 2 >;
204+
205+ arr_t d1 = arr_t ({static_cast <T>(5.4 ), static_cast <T>(-6.3 )});
206+ arr_t d2 = arr_t ({static_cast <T>(-7.2 ), static_cast <T>(8.1 )});
207+ arr_t d3 = arr_t ({static_cast <T>(1.0 ), static_cast <T>(-1.0 )});
208+
209+ arr_t f1 = arr_t ({static_cast <T>(1.8 ), static_cast <T>(-2.7 )});
210+ arr_t f2 = arr_t ({static_cast <T>(-3.6 ), static_cast <T>(4.5 )});
211+ arr_t f3 = arr_t ({static_cast <T>(1.0 ), static_cast <T>(-1.0 )});
212+
213+ arr_t ra1 = complex_add (complex_mul (d1, d2), d3);
214+ arr_t ra2 = complex_add (complex_mul (f1, f2), f3);
215+
216+ T expect[4 ] = {13.150000 , 88.100000 , 6.670001 , 16.820000 };
217+
218+ // complex_t r1 = d1 * d2 + d3;
219+ // complex_t r2 = f1 * f2 + f3;
220+
221+ std::cout << " r1: " << static_cast <T>(ra1[0 ]) << " , "
222+ << static_cast <T>(ra1[1 ]) << std::endl;
223+ std::cout << " Expect 1: " << expect[0 ] << " , " << expect[1 ] << std::endl;
224+ std::cout << " r2: " << static_cast <T>(ra2[0 ]) << " , "
225+ << static_cast <T>(ra2[1 ]) << std::endl;
226+ std::cout << " Expect 2: " << expect[2 ] << " , " << expect[3 ] << std::endl;
227+ }
228+
229+ template <typename T>
187230void kernel_mul_add (int *result) {
188- sycl::double2 d1, d2, d3;
189- sycl::float2 f1, f2, f3;
190- sycl::marray<double , 2 > m_d1, m_d2, m_d3;
191- sycl::marray<float , 2 > m_f1, m_f2, m_f3;
231+ sycl::vec<T, 2 > d1, d2, d3;
232+ sycl::vec<T, 2 > f1, f2, f3;
233+ sycl::marray<T , 2 > m_d1, m_d2, m_d3;
234+ sycl::marray<T , 2 > m_f1, m_f2, m_f3;
192235
193- d1 = sycl::double2 (5.4 , -6.3 );
194- d2 = sycl::double2 (-7.2 , 8.1 );
195- d3 = sycl::double2 (1.0 , -1.0 );
236+ d1 = sycl::vec<T, 2 > (5.4 , -6.3 );
237+ d2 = sycl::vec<T, 2 > (-7.2 , 8.1 );
238+ d3 = sycl::vec<T, 2 > (1.0 , -1.0 );
196239
197- f1 = sycl::float2 (1.8 , -2.7 );
198- f2 = sycl::float2 (-3.6 , 4.5 );
199- f3 = sycl::float2 (1.0 , -1.0 );
240+ f1 = sycl::vec<T, 2 > (1.8 , -2.7 );
241+ f2 = sycl::vec<T, 2 > (-3.6 , 4.5 );
242+ f3 = sycl::vec<T, 2 > (1.0 , -1.0 );
200243
201244 bool r = true ;
202245 float expect[4 ] = {13.150000 , 88.100000 , 6.670001 , 16.820000 };
@@ -207,13 +250,13 @@ void kernel_mul_add(int *result) {
207250 auto a2 = syclcompat::cmul_add (f1, f2, f3);
208251 r = r && check (a2, expect + 2 );
209252
210- m_d1 = sycl::marray<double , 2 >(5.4 , -6.3 );
211- m_d2 = sycl::marray<double , 2 >(-7.2 , 8.1 );
212- m_d3 = sycl::marray<double , 2 >(1.0 , -1.0 );
253+ m_d1 = sycl::marray<T , 2 >(5.4 , -6.3 );
254+ m_d2 = sycl::marray<T , 2 >(-7.2 , 8.1 );
255+ m_d3 = sycl::marray<T , 2 >(1.0 , -1.0 );
213256
214- m_f1 = sycl::marray<float , 2 >(1.8 , -2.7 );
215- m_f2 = sycl::marray<float , 2 >(-3.6 , 4.5 );
216- m_f3 = sycl::marray<float , 2 >(1.0 , -1.0 );
257+ m_f1 = sycl::marray<T , 2 >(1.8 , -2.7 );
258+ m_f2 = sycl::marray<T , 2 >(-3.6 , 4.5 );
259+ m_f3 = sycl::marray<T , 2 >(1.0 , -1.0 );
217260
218261 auto a3 = syclcompat::cmul_add (m_d1, m_d2, m_d3);
219262 r = r && check (a3, expect);
@@ -241,17 +284,19 @@ void test_conj() {
241284 ComplexLauncher<kernel_conj>().launch ();
242285}
243286
287+ template <typename T>
244288void test_mul_add () {
245289 std::cout << __PRETTY_FUNCTION__ << std::endl;
246- ComplexLauncher<kernel_mul_add>().launch ();
290+ mul_add_groundtruth<T>();
291+ ComplexLauncher<kernel_mul_add<T>>().launch ();
247292}
248293
249294int main () {
250295 test_abs ();
251296 test_mul ();
252297 test_div ();
253298 test_conj ();
254- test_mul_add ( );
299+ INSTANTIATE_ALL_TYPES (fp_type_list, test_mul_add );
255300
256301 return 0 ;
257302}
0 commit comments