Skip to content

Commit 9104c45

Browse files
authored
[SYCLomatic #2139] Add helper function test for dpct::blas::gels_batch_wrapper (#737)
Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent 666c71f commit 9104c45

File tree

3 files changed

+357
-1
lines changed

3 files changed

+357
-1
lines changed

help_function/help_function.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,6 @@
216216
<test testName="atomic_compare_exchange_strong" configFile="config/TEMPLATE_help_function.xml" />
217217
<test testName="filter_device" configFile="config/TEMPLATE_help_function.xml" />
218218
<test testName="blas_gemm_utils_interface" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
219+
<test testName="blas_utils_gels-usm" configFile="config/TEMPLATE_help_function_blas_usm.xml" splitGroup="double"/>
219220
</tests>
220221
</suite>
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
// ===------ blas_utils_gels-usm.cpp ---------------------- *- C++ -* ----=== //
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===--------------------------------------------------------------------=== //
8+
9+
#include <sycl/sycl.hpp>
10+
#include <dpct/dpct.hpp>
11+
#include <dpct/blas_utils.hpp>
12+
13+
#include <cmath>
14+
#include <cstdio>
15+
16+
bool all_pass = true;
17+
18+
void test1() {
19+
dpct::device_ext &dev_ct1 = dpct::get_current_device();
20+
sycl::queue &q_ct1 = dev_ct1.in_order_queue();
21+
float A[9] = {2, 3, 5, 7, 11, 13, 17, 19, 23};
22+
float B[9] = {1, 2, 3, 4, 5, 6, 7, 9, 9};
23+
24+
float *A_dev_mem;
25+
float *B_dev_mem;
26+
A_dev_mem = sycl::malloc_device<float>(18, q_ct1);
27+
B_dev_mem = sycl::malloc_device<float>(18, q_ct1);
28+
q_ct1.memcpy(A_dev_mem, A, sizeof(float) * 9);
29+
q_ct1.memcpy(A_dev_mem + 9, A, sizeof(float) * 9);
30+
q_ct1.memcpy(B_dev_mem, B, sizeof(float) * 9);
31+
q_ct1.memcpy(B_dev_mem + 9, B, sizeof(float) * 9).wait();
32+
33+
float **As;
34+
float **Bs;
35+
As = sycl::malloc_device<float *>(2, q_ct1);
36+
Bs = sycl::malloc_device<float *>(2, q_ct1);
37+
38+
q_ct1.memcpy(As, &A_dev_mem, sizeof(float *));
39+
float *temp_a = A_dev_mem + 9;
40+
q_ct1.memcpy(As + 1, &temp_a, sizeof(float *));
41+
q_ct1.memcpy(Bs, &B_dev_mem, sizeof(float *));
42+
float *temp_b = B_dev_mem + 9;
43+
q_ct1.memcpy(Bs + 1, &temp_b, sizeof(float *)).wait();
44+
45+
dpct::blas::descriptor_ptr handle;
46+
handle = new dpct::blas::descriptor();
47+
48+
int info;
49+
dpct::blas::gels_batch_wrapper(handle, oneapi::mkl::transpose::nontrans, 3, 3,
50+
3, As, 3, Bs, 3, &info, NULL, 2);
51+
q_ct1.wait();
52+
53+
float A_host_mem[18];
54+
float B_host_mem[18];
55+
q_ct1.memcpy(A_host_mem, A_dev_mem, sizeof(float) * 18);
56+
q_ct1.memcpy(B_host_mem, B_dev_mem, sizeof(float) * 18).wait();
57+
58+
float A_ref[18] = {-6.164414, 0.367448, 0.612414, -18.168798, -2.982405,
59+
-0.509851, -33.417614, -6.653060, -4.242642, -6.164414,
60+
0.367448, 0.612414, -18.168798, -2.982405, -0.509851,
61+
-33.417614, -6.653060, -4.242642};
62+
float B_ref[18] = {0.461538, 0.166667, -0.064103, 0.000000, 0.166667,
63+
0.166667, -1.230769, 0.666667, 0.282051, 0.461538,
64+
0.166667, -0.064103, 0.000000, 0.166667, 0.166667,
65+
-1.230769, 0.666667, 0.282051};
66+
67+
bool pass = true;
68+
for (int i = 0; i < 18; i++) {
69+
if (std::fabs(A_ref[i] - A_host_mem[i]) > 0.01) {
70+
pass = false;
71+
break;
72+
}
73+
if (std::fabs(B_ref[i] - B_host_mem[i]) > 0.01) {
74+
pass = false;
75+
break;
76+
}
77+
}
78+
79+
if (pass) {
80+
printf("test1 pass\n");
81+
return;
82+
}
83+
printf("test1 fail\n");
84+
printf("a:\n");
85+
for (int i = 0; i < 18; i++) {
86+
printf("%f, ", A_host_mem[i]);
87+
}
88+
printf("\n");
89+
printf("b:\n");
90+
for (int i = 0; i < 18; i++) {
91+
printf("%f, ", B_host_mem[i]);
92+
}
93+
printf("\n");
94+
all_pass = false;
95+
}
96+
97+
void test2() {
98+
dpct::device_ext &dev_ct1 = dpct::get_current_device();
99+
sycl::queue &q_ct1 = dev_ct1.in_order_queue();
100+
double A[9] = {2, 3, 5, 7, 11, 13, 17, 19, 23};
101+
double B[9] = {1, 2, 3, 4, 5, 6, 7, 9, 9};
102+
103+
double *A_dev_mem;
104+
double *B_dev_mem;
105+
A_dev_mem = sycl::malloc_device<double>(18, q_ct1);
106+
B_dev_mem = sycl::malloc_device<double>(18, q_ct1);
107+
q_ct1.memcpy(A_dev_mem, A, sizeof(double) * 9);
108+
q_ct1.memcpy(A_dev_mem + 9, A, sizeof(double) * 9);
109+
q_ct1.memcpy(B_dev_mem, B, sizeof(double) * 9);
110+
q_ct1.memcpy(B_dev_mem + 9, B, sizeof(double) * 9).wait();
111+
112+
double **As;
113+
double **Bs;
114+
As = sycl::malloc_device<double *>(2, q_ct1);
115+
Bs = sycl::malloc_device<double *>(2, q_ct1);
116+
117+
q_ct1.memcpy(As, &A_dev_mem, sizeof(double *));
118+
double *temp_a = A_dev_mem + 9;
119+
q_ct1.memcpy(As + 1, &temp_a, sizeof(double *));
120+
q_ct1.memcpy(Bs, &B_dev_mem, sizeof(double *));
121+
double *temp_b = B_dev_mem + 9;
122+
q_ct1.memcpy(Bs + 1, &temp_b, sizeof(double *)).wait();
123+
124+
dpct::blas::descriptor_ptr handle;
125+
handle = new dpct::blas::descriptor();
126+
127+
int info;
128+
dpct::blas::gels_batch_wrapper(handle, oneapi::mkl::transpose::nontrans, 3, 3,
129+
3, As, 3, Bs, 3, &info, NULL, 2);
130+
q_ct1.wait();
131+
132+
double A_host_mem[18];
133+
double B_host_mem[18];
134+
q_ct1.memcpy(A_host_mem, A_dev_mem, sizeof(double) * 18);
135+
q_ct1.memcpy(B_host_mem, B_dev_mem, sizeof(double) * 18).wait();
136+
137+
double A_ref[18] = {-6.164414, 0.367448, 0.612414, -18.168798, -2.982405,
138+
-0.509851, -33.417614, -6.653060, -4.242642, -6.164414,
139+
0.367448, 0.612414, -18.168798, -2.982405, -0.509851,
140+
-33.417614, -6.653060, -4.242642};
141+
double B_ref[18] = {0.461538, 0.166667, -0.064103, 0.000000, 0.166667,
142+
0.166667, -1.230769, 0.666667, 0.282051, 0.461538,
143+
0.166667, -0.064103, 0.000000, 0.166667, 0.166667,
144+
-1.230769, 0.666667, 0.282051};
145+
146+
bool pass = true;
147+
for (int i = 0; i < 18; i++) {
148+
if (std::fabs(A_ref[i] - A_host_mem[i]) > 0.01) {
149+
pass = false;
150+
break;
151+
}
152+
if (std::fabs(B_ref[i] - B_host_mem[i]) > 0.01) {
153+
pass = false;
154+
break;
155+
}
156+
}
157+
158+
if (pass) {
159+
printf("test2 pass\n");
160+
return;
161+
}
162+
printf("test2 fail\n");
163+
printf("a:\n");
164+
for (int i = 0; i < 18; i++) {
165+
printf("%f, ", A_host_mem[i]);
166+
}
167+
printf("\n");
168+
printf("b:\n");
169+
for (int i = 0; i < 18; i++) {
170+
printf("%f, ", B_host_mem[i]);
171+
}
172+
printf("\n");
173+
all_pass = false;
174+
}
175+
176+
void test3() {
177+
dpct::device_ext &dev_ct1 = dpct::get_current_device();
178+
sycl::queue &q_ct1 = dev_ct1.in_order_queue();
179+
sycl::float2 A[9] = {
180+
sycl::float2(2, 0), sycl::float2(3, 0), sycl::float2(5, 0),
181+
sycl::float2(7, 0), sycl::float2(11, 0), sycl::float2(13, 0),
182+
sycl::float2(17, 0), sycl::float2(19, 0), sycl::float2(23, 0)};
183+
sycl::float2 B[9] = {
184+
sycl::float2(1, 0), sycl::float2(2, 0), sycl::float2(3, 0),
185+
sycl::float2(4, 0), sycl::float2(5, 0), sycl::float2(6, 0),
186+
sycl::float2(7, 0), sycl::float2(9, 0), sycl::float2(9, 0)};
187+
188+
sycl::float2 *A_dev_mem;
189+
sycl::float2 *B_dev_mem;
190+
A_dev_mem = sycl::malloc_device<sycl::float2>(18, q_ct1);
191+
B_dev_mem = sycl::malloc_device<sycl::float2>(18, q_ct1);
192+
q_ct1.memcpy(A_dev_mem, A, sizeof(sycl::float2) * 9);
193+
q_ct1.memcpy(A_dev_mem + 9, A, sizeof(sycl::float2) * 9);
194+
q_ct1.memcpy(B_dev_mem, B, sizeof(sycl::float2) * 9);
195+
q_ct1.memcpy(B_dev_mem + 9, B, sizeof(sycl::float2) * 9).wait();
196+
197+
sycl::float2 **As;
198+
sycl::float2 **Bs;
199+
As = sycl::malloc_device<sycl::float2 *>(2, q_ct1);
200+
Bs = sycl::malloc_device<sycl::float2 *>(2, q_ct1);
201+
202+
q_ct1.memcpy(As, &A_dev_mem, sizeof(sycl::float2 *));
203+
sycl::float2 *temp_a = A_dev_mem + 9;
204+
q_ct1.memcpy(As + 1, &temp_a, sizeof(sycl::float2 *));
205+
q_ct1.memcpy(Bs, &B_dev_mem, sizeof(sycl::float2 *));
206+
sycl::float2 *temp_b = B_dev_mem + 9;
207+
q_ct1.memcpy(Bs + 1, &temp_b, sizeof(sycl::float2 *)).wait();
208+
209+
dpct::blas::descriptor_ptr handle;
210+
handle = new dpct::blas::descriptor();
211+
212+
int info;
213+
dpct::blas::gels_batch_wrapper(handle, oneapi::mkl::transpose::nontrans, 3, 3,
214+
3, As, 3, Bs, 3, &info, NULL, 2);
215+
q_ct1.wait();
216+
217+
sycl::float2 A_host_mem[18];
218+
sycl::float2 B_host_mem[18];
219+
q_ct1.memcpy(A_host_mem, A_dev_mem, sizeof(sycl::float2) * 18);
220+
q_ct1.memcpy(B_host_mem, B_dev_mem, sizeof(sycl::float2) * 18).wait();
221+
222+
float A_ref[18] = {-6.164414, 0.367448, 0.612414, -18.168798, -2.982405,
223+
-0.509851, -33.417614, -6.653060, -4.242642, -6.164414,
224+
0.367448, 0.612414, -18.168798, -2.982405, -0.509851,
225+
-33.417614, -6.653060, -4.242642};
226+
float B_ref[18] = {0.461538, 0.166667, -0.064103, 0.000000, 0.166667,
227+
0.166667, -1.230769, 0.666667, 0.282051, 0.461538,
228+
0.166667, -0.064103, 0.000000, 0.166667, 0.166667,
229+
-1.230769, 0.666667, 0.282051};
230+
231+
bool pass = true;
232+
for (int i = 0; i < 18; i++) {
233+
if (std::fabs(A_ref[i] - A_host_mem[i].x()) > 0.01) {
234+
pass = false;
235+
break;
236+
}
237+
if (std::fabs(B_ref[i] - B_host_mem[i].x()) > 0.01) {
238+
pass = false;
239+
break;
240+
}
241+
}
242+
243+
if (pass) {
244+
printf("test3 pass\n");
245+
return;
246+
}
247+
printf("test3 fail\n");
248+
printf("a:\n");
249+
for (int i = 0; i < 18; i++) {
250+
printf("%f, ", A_host_mem[i].x());
251+
}
252+
printf("\n");
253+
printf("b:\n");
254+
for (int i = 0; i < 18; i++) {
255+
printf("%f, ", B_host_mem[i].x());
256+
}
257+
printf("\n");
258+
all_pass = false;
259+
}
260+
261+
void test4() {
262+
dpct::device_ext &dev_ct1 = dpct::get_current_device();
263+
sycl::queue &q_ct1 = dev_ct1.in_order_queue();
264+
sycl::double2 A[9] = {
265+
sycl::double2(2, 0), sycl::double2(3, 0), sycl::double2(5, 0),
266+
sycl::double2(7, 0), sycl::double2(11, 0), sycl::double2(13, 0),
267+
sycl::double2(17, 0), sycl::double2(19, 0), sycl::double2(23, 0)};
268+
sycl::double2 B[9] = {
269+
sycl::double2(1, 0), sycl::double2(2, 0), sycl::double2(3, 0),
270+
sycl::double2(4, 0), sycl::double2(5, 0), sycl::double2(6, 0),
271+
sycl::double2(7, 0), sycl::double2(9, 0), sycl::double2(9, 0)};
272+
273+
sycl::double2 *A_dev_mem;
274+
sycl::double2 *B_dev_mem;
275+
A_dev_mem = sycl::malloc_device<sycl::double2>(18, q_ct1);
276+
B_dev_mem = sycl::malloc_device<sycl::double2>(18, q_ct1);
277+
q_ct1.memcpy(A_dev_mem, A, sizeof(sycl::double2) * 9);
278+
q_ct1.memcpy(A_dev_mem + 9, A, sizeof(sycl::double2) * 9);
279+
q_ct1.memcpy(B_dev_mem, B, sizeof(sycl::double2) * 9);
280+
q_ct1.memcpy(B_dev_mem + 9, B, sizeof(sycl::double2) * 9).wait();
281+
282+
sycl::double2 **As;
283+
sycl::double2 **Bs;
284+
As = sycl::malloc_device<sycl::double2 *>(2, q_ct1);
285+
Bs = sycl::malloc_device<sycl::double2 *>(2, q_ct1);
286+
287+
q_ct1.memcpy(As, &A_dev_mem, sizeof(sycl::double2 *));
288+
sycl::double2 *temp_a = A_dev_mem + 9;
289+
q_ct1.memcpy(As + 1, &temp_a, sizeof(sycl::double2 *));
290+
q_ct1.memcpy(Bs, &B_dev_mem, sizeof(sycl::double2 *));
291+
sycl::double2 *temp_b = B_dev_mem + 9;
292+
q_ct1.memcpy(Bs + 1, &temp_b, sizeof(sycl::double2 *)).wait();
293+
294+
dpct::blas::descriptor_ptr handle;
295+
handle = new dpct::blas::descriptor();
296+
297+
int info;
298+
dpct::blas::gels_batch_wrapper(handle, oneapi::mkl::transpose::nontrans, 3, 3,
299+
3, As, 3, Bs, 3, &info, NULL, 2);
300+
q_ct1.wait();
301+
302+
sycl::double2 A_host_mem[18];
303+
sycl::double2 B_host_mem[18];
304+
q_ct1.memcpy(A_host_mem, A_dev_mem, sizeof(sycl::double2) * 18);
305+
q_ct1.memcpy(B_host_mem, B_dev_mem, sizeof(sycl::double2) * 18).wait();
306+
307+
double A_ref[18] = {-6.164414, 0.367448, 0.612414, -18.168798, -2.982405,
308+
-0.509851, -33.417614, -6.653060, -4.242642, -6.164414,
309+
0.367448, 0.612414, -18.168798, -2.982405, -0.509851,
310+
-33.417614, -6.653060, -4.242642};
311+
double B_ref[18] = {0.461538, 0.166667, -0.064103, 0.000000, 0.166667,
312+
0.166667, -1.230769, 0.666667, 0.282051, 0.461538,
313+
0.166667, -0.064103, 0.000000, 0.166667, 0.166667,
314+
-1.230769, 0.666667, 0.282051};
315+
316+
bool pass = true;
317+
for (int i = 0; i < 18; i++) {
318+
if (std::fabs(A_ref[i] - A_host_mem[i].x()) > 0.01) {
319+
pass = false;
320+
break;
321+
}
322+
if (std::fabs(B_ref[i] - B_host_mem[i].x()) > 0.01) {
323+
pass = false;
324+
break;
325+
}
326+
}
327+
328+
if (pass) {
329+
printf("test4 pass\n");
330+
return;
331+
}
332+
printf("test4 fail\n");
333+
printf("a:\n");
334+
for (int i = 0; i < 18; i++) {
335+
printf("%f, ", A_host_mem[i].x());
336+
}
337+
printf("\n");
338+
printf("b:\n");
339+
for (int i = 0; i < 18; i++) {
340+
printf("%f, ", B_host_mem[i].x());
341+
}
342+
printf("\n");
343+
all_pass = false;
344+
}
345+
346+
int main() {
347+
test1();
348+
test2();
349+
test3();
350+
test4();
351+
if (all_pass)
352+
return 0;
353+
return 1;
354+
}

help_function/test_help.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def build_test():
3838
"blas_utils_geqrf-complex-usm", "blas_utils_get_transpose", "blas_utils_get_value",
3939
"blas_utils_get_value_usm", "blas_extension_api_buffer", "lib_common_utils_mkl_get_version",
4040
"blas_extension_api_usm", "blas_utils_getrfnp", "blas_utils_getrfnp-usm", "blas_utils_getrfnp-complex",
41-
"blas_utils_getrfnp-complex-usm", "blas_utils_parameter_wrapper_buf", "blas_utils_parameter_wrapper_usm"]
41+
"blas_utils_getrfnp-complex-usm", "blas_utils_parameter_wrapper_buf", "blas_utils_parameter_wrapper_usm",
42+
"blas_utils_gels-usm"]
4243
oneDNN_related = ["dnnl_utils_activation", "dnnl_utils_fill", "dnnl_utils_lrn", "dnnl_utils_memory",
4344
"dnnl_utils_pooling", "dnnl_utils_reorder", "dnnl_utils_scale", "dnnl_utils_softmax",
4445
"dnnl_utils_sum", "dnnl_utils_reduction", "dnnl_utils_binary", "dnnl_utils_batch_normalization_1",

0 commit comments

Comments
 (0)