@@ -184,62 +184,19 @@ void kernel_mul(int *result) {
184184 *result = r;
185185}
186186
187- template <typename T>
188- std::array<T,2 > complex_mul (std::array<T, 2 > a, std::array<T, 2 > b){
189- std::array<T, 2 > result;
190- result[0 ] = (a[0 ] * b[0 ]) - (a[1 ] * b[1 ]);
191- result[1 ] = (a[0 ] * b[1 ]) + (a[1 ] * b[0 ]);
192- return result;
193- }
194-
195- template <typename T>
196- std::array<T,2 > complex_add (std::array<T, 2 > a, std::array<T, 2 > b){
197- return {a[0 ] + b[0 ], a[1 ] + b[1 ]};
198- }
199-
200- template <typename T> void mul_add_groundtruth () {
201-
202- using complex_t = std::complex <double >;
203- using arr_t = std::array<T, 2 >;
204-
205- arr_t d1 = arr_t ({static_cast <T>(5.4 ), static_cast <T>(-6.3 )});
206- arr_t d2 = arr_t ({static_cast <T>(-7.2 ), static_cast <T>(8.1 )});
207- arr_t d3 = arr_t ({static_cast <T>(1.0 ), static_cast <T>(-1.0 )});
208-
209- arr_t f1 = arr_t ({static_cast <T>(1.8 ), static_cast <T>(-2.7 )});
210- arr_t f2 = arr_t ({static_cast <T>(-3.6 ), static_cast <T>(4.5 )});
211- arr_t f3 = arr_t ({static_cast <T>(1.0 ), static_cast <T>(-1.0 )});
212-
213- arr_t ra1 = complex_add (complex_mul (d1, d2), d3);
214- arr_t ra2 = complex_add (complex_mul (f1, f2), f3);
215-
216- T expect[4 ] = {13.150000 , 88.100000 , 6.670001 , 16.820000 };
217-
218- // complex_t r1 = d1 * d2 + d3;
219- // complex_t r2 = f1 * f2 + f3;
220-
221- std::cout << " r1: " << static_cast <T>(ra1[0 ]) << " , "
222- << static_cast <T>(ra1[1 ]) << std::endl;
223- std::cout << " Expect 1: " << expect[0 ] << " , " << expect[1 ] << std::endl;
224- std::cout << " r2: " << static_cast <T>(ra2[0 ]) << " , "
225- << static_cast <T>(ra2[1 ]) << std::endl;
226- std::cout << " Expect 2: " << expect[2 ] << " , " << expect[3 ] << std::endl;
227- }
228-
229- template <typename T>
230187void kernel_mul_add (int *result) {
231- sycl::vec<T, 2 > d1, d2, d3;
232- sycl::vec<T, 2 > f1, f2, f3;
233- sycl::marray<T , 2 > m_d1, m_d2, m_d3;
234- sycl::marray<T , 2 > m_f1, m_f2, m_f3;
188+ sycl::double2 d1, d2, d3;
189+ sycl::float2 f1, f2, f3;
190+ sycl::marray<double , 2 > m_d1, m_d2, m_d3;
191+ sycl::marray<float , 2 > m_f1, m_f2, m_f3;
235192
236- d1 = sycl::vec<T, 2 > (5.4 , -6.3 );
237- d2 = sycl::vec<T, 2 > (-7.2 , 8.1 );
238- d3 = sycl::vec<T, 2 > (1.0 , -1.0 );
193+ d1 = sycl::double2 (5.4 , -6.3 );
194+ d2 = sycl::double2 (-7.2 , 8.1 );
195+ d3 = sycl::double2 (1.0 , -1.0 );
239196
240- f1 = sycl::vec<T, 2 > (1.8 , -2.7 );
241- f2 = sycl::vec<T, 2 > (-3.6 , 4.5 );
242- f3 = sycl::vec<T, 2 > (1.0 , -1.0 );
197+ f1 = sycl::float2 (1.8 , -2.7 );
198+ f2 = sycl::float2 (-3.6 , 4.5 );
199+ f3 = sycl::float2 (1.0 , -1.0 );
243200
244201 bool r = true ;
245202 float expect[4 ] = {13.150000 , 88.100000 , 6.670001 , 16.820000 };
@@ -250,13 +207,13 @@ void kernel_mul_add(int *result) {
250207 auto a2 = syclcompat::cmul_add (f1, f2, f3);
251208 r = r && check (a2, expect + 2 );
252209
253- m_d1 = sycl::marray<T , 2 >(5.4 , -6.3 );
254- m_d2 = sycl::marray<T , 2 >(-7.2 , 8.1 );
255- m_d3 = sycl::marray<T , 2 >(1.0 , -1.0 );
210+ m_d1 = sycl::marray<double , 2 >(5.4 , -6.3 );
211+ m_d2 = sycl::marray<double , 2 >(-7.2 , 8.1 );
212+ m_d3 = sycl::marray<double , 2 >(1.0 , -1.0 );
256213
257- m_f1 = sycl::marray<T , 2 >(1.8 , -2.7 );
258- m_f2 = sycl::marray<T , 2 >(-3.6 , 4.5 );
259- m_f3 = sycl::marray<T , 2 >(1.0 , -1.0 );
214+ m_f1 = sycl::marray<float , 2 >(1.8 , -2.7 );
215+ m_f2 = sycl::marray<float , 2 >(-3.6 , 4.5 );
216+ m_f3 = sycl::marray<float , 2 >(1.0 , -1.0 );
260217
261218 auto a3 = syclcompat::cmul_add (m_d1, m_d2, m_d3);
262219 r = r && check (a3, expect);
@@ -284,19 +241,17 @@ void test_conj() {
284241 ComplexLauncher<kernel_conj>().launch ();
285242}
286243
287- template <typename T>
288244void test_mul_add () {
289245 std::cout << __PRETTY_FUNCTION__ << std::endl;
290- mul_add_groundtruth<T>();
291- ComplexLauncher<kernel_mul_add<T>>().launch ();
246+ ComplexLauncher<kernel_mul_add>().launch ();
292247}
293248
294249int main () {
295250 test_abs ();
296251 test_mul ();
297252 test_div ();
298253 test_conj ();
299- INSTANTIATE_ALL_TYPES (fp_type_list, test_mul_add);
254+ test_mul_add ( );
300255
301256 return 0 ;
302257}
0 commit comments