Skip to content

Commit 3d9ded6

Browse files
authored
[SYCL] Change check_device_code CUDA tests to use SYCL_EXTERNAL (#13943)
Changed CUDA sycl/test/check_device_code lit test cases to use SYCL_EXTERNAL functions instead of submitting kernels to the queue every time.
1 parent b925bd8 commit 3d9ded6

File tree

9 files changed

+1584
-1590
lines changed

9 files changed

+1584
-1590
lines changed

sycl/test/check_device_code/cuda/ldg.cpp

Lines changed: 206 additions & 276 deletions
Large diffs are not rendered by default.

sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp

Lines changed: 227 additions & 212 deletions
Large diffs are not rendered by default.

sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp

Lines changed: 75 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -15,96 +15,78 @@ constexpr int N = 8; // number of cols of accumulator,
1515
// number of rows of a.
1616
constexpr int K = 4; // number of cols of a/number of rows of b.
1717

18-
// Removed side of the diff: file-scope double arrays that the old main()
// wrapped in 1-D SYCL buffers. A holds M*K elements, B holds K*N, and C and D
// hold M*N each.
double A[M * K];
19-
double B[K * N];
20-
double C[M * N];
21-
double D[M * N];
22-
23-
// Removed side of the diff: the previous form of this test. It wraps the
// global arrays A, B, C and D in 1-D double buffers, creates a queue, and
// submits a single command group that builds four read_write device accessors
// and then issues two parallel_for kernels (row_row and col_col) on the same
// handler. Always returns 0.
// NOTE(review): two parallel_for calls share one handler here; presumably this
// was tolerated because the test only inspects generated device code - confirm.
int main() {
24-
25-
buffer<double, 1> bufA(A, range<1>(M * K));
26-
buffer<double, 1> bufB(B, range<1>(K * N));
27-
buffer<double, 1> bufC(C, range<1>(M * N));
28-
buffer<double, 1> bufD(D, range<1>(M * N));
29-
30-
queue q;
31-
32-
q.submit([&](handler &cgh) {
33-
sycl::accessor<double, 1, sycl::access::mode::read_write,
34-
sycl::target::device>
35-
accA(bufA, cgh);
36-
sycl::accessor<double, 1, sycl::access::mode::read_write,
37-
sycl::target::device>
38-
accB(bufB, cgh);
39-
sycl::accessor<double, 1, sycl::access::mode::read_write,
40-
sycl::target::device>
41-
accC(bufC, cgh);
42-
sycl::accessor<double, 1, sycl::access::mode::read_write,
43-
sycl::target::device>
44-
accD(bufD, cgh);
45-
46-
// Kernel 1: A and B fragments are declared row_major; the loads and the store
// use N, K, N and N as strides, all with row_major layout.
cgh.parallel_for<class row_row>(
47-
nd_range<2>({1, 32}, {1, 32}),
48-
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
49-
sycl::sub_group sg = item.get_sub_group();
50-
51-
joint_matrix<sub_group, double, use::accumulator, M, N> sub_c{};
52-
joint_matrix<sub_group, double, use::a, M, K, layout::row_major>
53-
sub_a{};
54-
joint_matrix<sub_group, double, use::b, K, N, layout::row_major>
55-
sub_b{};
56-
57-
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
58-
joint_matrix_load(
59-
sg, sub_c, accC.template get_multi_ptr<access::decorated::yes>(),
60-
N, layout::row_major);
61-
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
62-
joint_matrix_load(
63-
sg, sub_a, accA.template get_multi_ptr<access::decorated::yes>(),
64-
K);
65-
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
66-
joint_matrix_load(
67-
sg, sub_b, accB.template get_multi_ptr<access::decorated::yes>(),
68-
N);
69-
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
70-
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
71-
//CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
72-
joint_matrix_store(
73-
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
74-
N, layout::row_major);
75-
});
76-
77-
// Kernel 2: same sequence with col_major A and B fragments; the loads and the
// store use M, M, K and M as strides, with col_major layout.
cgh.parallel_for<class col_col>(
78-
nd_range<2>({1, 32}, {1, 32}),
79-
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
80-
sycl::sub_group sg = item.get_sub_group();
81-
82-
joint_matrix<sub_group, double, use::accumulator, M, N> sub_c{};
83-
joint_matrix<sub_group, double, use::a, M, K, layout::col_major>
84-
sub_a{};
85-
joint_matrix<sub_group, double, use::b, K, N, layout::col_major>
86-
sub_b{};
87-
88-
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
89-
joint_matrix_load(
90-
sg, sub_c, accC.template get_multi_ptr<access::decorated::yes>(),
91-
M, layout::col_major);
92-
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
93-
joint_matrix_load(
94-
sg, sub_a, accA.template get_multi_ptr<access::decorated::yes>(),
95-
M);
96-
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
97-
joint_matrix_load(
98-
sg, sub_b, accB.template get_multi_ptr<access::decorated::yes>(),
99-
K);
100-
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
101-
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
102-
//CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
103-
joint_matrix_store(
104-
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
105-
M, layout::col_major);
106-
});
107-
});
108-
109-
return 0;
110-
};
18+
// Added side of the diff: device-code check for the double-precision
// joint_matrix sequence with row_major A and B fragments.
//
// Per the commit description, the test now uses a SYCL_EXTERNAL function
// instead of submitting a kernel to a queue; the function carries the same
// reqd_work_group_size(1, 1, 32) attribute the old kernel lambda had.
//
// Parameters:
//   accA, accB - 1-D read_write device accessors of double, loaded into the
//                `a` (M x K) and `b` (K x N) fragments.
//   accC       - accessor loaded into the M x N accumulator fragment.
//   accD       - accessor the accumulator fragment is stored to.
//   item       - nd_item<2>; used only to obtain the sub-group.
//
// The existing check-directive comments name the NVVM wmma intrinsic expected
// for the call that follows each of them (presumably matched by FileCheck
// against the emitted IR - confirm against the RUN lines, not visible here).
// Keep their order and text unchanged.
SYCL_EXTERNAL [[sycl::reqd_work_group_size(1, 1, 32)]] void
19+
row_row_m8n8k4(sycl::accessor<double, 1, sycl::access::mode::read_write,
20+
sycl::target::device>
21+
accA,
22+
sycl::accessor<double, 1, sycl::access::mode::read_write,
23+
sycl::target::device>
24+
accB,
25+
sycl::accessor<double, 1, sycl::access::mode::read_write,
26+
sycl::target::device>
27+
accC,
28+
sycl::accessor<double, 1, sycl::access::mode::read_write,
29+
sycl::target::device>
30+
accD,
31+
nd_item<2> item) {
32+
sycl::sub_group sg = item.get_sub_group();
33+
34+
// Fragments: accumulator is M x N, a is M x K, b is K x N. N = 8 and K = 4
// are declared above this function; M is presumably 8 (its declaration is
// not visible in this view).
joint_matrix<sub_group, double, use::accumulator, M, N> sub_c{};
35+
joint_matrix<sub_group, double, use::a, M, K, layout::row_major> sub_a{};
36+
joint_matrix<sub_group, double, use::b, K, N, layout::row_major> sub_b{};
37+
38+
// Row-major loads: the stride argument is the column count of each matrix
// (N for the accumulator, K for a, N for b), matching the i32 8 / 4 / 8
// operands in the expected intrinsic calls.
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
39+
joint_matrix_load(sg, sub_c,
40+
accC.template get_multi_ptr<access::decorated::yes>(), N,
41+
layout::row_major);
42+
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
43+
joint_matrix_load(sg, sub_a,
44+
accA.template get_multi_ptr<access::decorated::yes>(), K);
45+
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
46+
joint_matrix_load(sg, sub_b,
47+
accB.template get_multi_ptr<access::decorated::yes>(), N);
// sub_c is passed both as the second argument and as the last argument, so
// the accumulator fragment is reused for the result.
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
48+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
49+
//CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
50+
joint_matrix_store(sg, sub_c,
51+
accD.template get_multi_ptr<access::decorated::yes>(), N,
52+
layout::row_major);
53+
}
55+
56+
// Added side of the diff: device-code check for the double-precision
// joint_matrix sequence with col_major A and B fragments.
//
// Per the commit description, the test now uses a SYCL_EXTERNAL function
// instead of submitting a kernel to a queue; the function carries the same
// reqd_work_group_size(1, 1, 32) attribute the old kernel lambda had.
//
// Parameters:
//   accA, accB - 1-D read_write device accessors of double, loaded into the
//                `a` (M x K) and `b` (K x N) fragments.
//   accC       - accessor loaded into the M x N accumulator fragment.
//   accD       - accessor the accumulator fragment is stored to.
//   item       - nd_item<2>; used only to obtain the sub-group.
//
// The existing check-directive comments name the NVVM wmma intrinsic expected
// for the call that follows each of them (presumably matched by FileCheck
// against the emitted IR - confirm against the RUN lines, not visible here).
// Keep their order and text unchanged.
SYCL_EXTERNAL [[sycl::reqd_work_group_size(1, 1, 32)]] void
57+
col_col_m8n8k4(sycl::accessor<double, 1, sycl::access::mode::read_write,
58+
sycl::target::device>
59+
accA,
60+
sycl::accessor<double, 1, sycl::access::mode::read_write,
61+
sycl::target::device>
62+
accB,
63+
sycl::accessor<double, 1, sycl::access::mode::read_write,
64+
sycl::target::device>
65+
accC,
66+
sycl::accessor<double, 1, sycl::access::mode::read_write,
67+
sycl::target::device>
68+
accD,
69+
nd_item<2> item) {
70+
sycl::sub_group sg = item.get_sub_group();
71+
72+
// Fragments: accumulator is M x N, a is M x K, b is K x N. N = 8 and K = 4
// are declared above this function; M is presumably 8 (its declaration is
// not visible in this view).
joint_matrix<sub_group, double, use::accumulator, M, N> sub_c{};
73+
joint_matrix<sub_group, double, use::a, M, K, layout::col_major> sub_a{};
74+
joint_matrix<sub_group, double, use::b, K, N, layout::col_major> sub_b{};
75+
76+
// Column-major loads: the stride argument is the row count of each matrix
// (M for the accumulator, M for a, K for b), matching the i32 8 / 8 / 4
// operands in the expected intrinsic calls.
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
77+
joint_matrix_load(sg, sub_c,
78+
accC.template get_multi_ptr<access::decorated::yes>(), M,
79+
layout::col_major);
80+
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
81+
joint_matrix_load(sg, sub_a,
82+
accA.template get_multi_ptr<access::decorated::yes>(), M);
83+
//CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
84+
joint_matrix_load(sg, sub_b,
85+
accB.template get_multi_ptr<access::decorated::yes>(), K);
// sub_c is passed both as the second argument and as the last argument, so
// the accumulator fragment is reused for the result.
//CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
86+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
87+
//CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
88+
joint_matrix_store(sg, sub_c,
89+
accD.template get_multi_ptr<access::decorated::yes>(), M,
90+
layout::col_major);
91+
}

0 commit comments

Comments
 (0)