@@ -15,96 +15,78 @@ constexpr int N = 8; // number of cols of accumulator,
1515 // number of rows of a.
1616constexpr int K = 4 ; // number of cols of a/number of rows of b.
1717
18- double A[M * K];
19- double B[K * N];
20- double C[M * N];
21- double D[M * N];
22-
23- int main () {
24-
25- buffer<double , 1 > bufA (A, range<1 >(M * K));
26- buffer<double , 1 > bufB (B, range<1 >(K * N));
27- buffer<double , 1 > bufC (C, range<1 >(M * N));
28- buffer<double , 1 > bufD (D, range<1 >(M * N));
29-
30- queue q;
31-
32- q.submit ([&](handler &cgh) {
33- sycl::accessor<double , 1 , sycl::access::mode::read_write,
34- sycl::target::device>
35- accA (bufA, cgh);
36- sycl::accessor<double , 1 , sycl::access::mode::read_write,
37- sycl::target::device>
38- accB (bufB, cgh);
39- sycl::accessor<double , 1 , sycl::access::mode::read_write,
40- sycl::target::device>
41- accC (bufC, cgh);
42- sycl::accessor<double , 1 , sycl::access::mode::read_write,
43- sycl::target::device>
44- accD (bufD, cgh);
45-
46- cgh.parallel_for <class row_row >(
47- nd_range<2 >({1 , 32 }, {1 , 32 }),
48- [=](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
49- sycl::sub_group sg = item.get_sub_group ();
50-
51- joint_matrix<sub_group, double , use::accumulator, M, N> sub_c{};
52- joint_matrix<sub_group, double , use::a, M, K, layout::row_major>
53- sub_a{};
54- joint_matrix<sub_group, double , use::b, K, N, layout::row_major>
55- sub_b{};
56-
57- // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
58- joint_matrix_load (
59- sg, sub_c, accC.template get_multi_ptr <access::decorated::yes>(),
60- N, layout::row_major);
61- // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
62- joint_matrix_load (
63- sg, sub_a, accA.template get_multi_ptr <access::decorated::yes>(),
64- K);
65- // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
66- joint_matrix_load (
67- sg, sub_b, accB.template get_multi_ptr <access::decorated::yes>(),
68- N);
69- // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
70- joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
71- // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
72- joint_matrix_store (
73- sg, sub_c, accD.template get_multi_ptr <access::decorated::yes>(),
74- N, layout::row_major);
75- });
76-
77- cgh.parallel_for <class col_col >(
78- nd_range<2 >({1 , 32 }, {1 , 32 }),
79- [=](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
80- sycl::sub_group sg = item.get_sub_group ();
81-
82- joint_matrix<sub_group, double , use::accumulator, M, N> sub_c{};
83- joint_matrix<sub_group, double , use::a, M, K, layout::col_major>
84- sub_a{};
85- joint_matrix<sub_group, double , use::b, K, N, layout::col_major>
86- sub_b{};
87-
88- // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
89- joint_matrix_load (
90- sg, sub_c, accC.template get_multi_ptr <access::decorated::yes>(),
91- M, layout::col_major);
92- // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
93- joint_matrix_load (
94- sg, sub_a, accA.template get_multi_ptr <access::decorated::yes>(),
95- M);
96- // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
97- joint_matrix_load (
98- sg, sub_b, accB.template get_multi_ptr <access::decorated::yes>(),
99- K);
100- // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
101- joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
102- // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
103- joint_matrix_store (
104- sg, sub_c, accD.template get_multi_ptr <access::decorated::yes>(),
105- M, layout::col_major);
106- });
107- });
108-
109- return 0 ;
110- };
18+ SYCL_EXTERNAL [[sycl::reqd_work_group_size(1 , 1 , 32 )]] void
19+ row_row_m8n8k4 (sycl::accessor<double , 1 , sycl::access::mode::read_write,
20+ sycl::target::device>
21+ accA,
22+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
23+ sycl::target::device>
24+ accB,
25+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
26+ sycl::target::device>
27+ accC,
28+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
29+ sycl::target::device>
30+ accD,
31+ nd_item<2 > item) {
32+ sycl::sub_group sg = item.get_sub_group ();
33+
34+ joint_matrix<sub_group, double , use::accumulator, M, N> sub_c{};
35+ joint_matrix<sub_group, double , use::a, M, K, layout::row_major> sub_a{};
36+ joint_matrix<sub_group, double , use::b, K, N, layout::row_major> sub_b{};
37+
38+ // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
39+ joint_matrix_load (sg, sub_c,
40+ accC.template get_multi_ptr <access::decorated::yes>(), N,
41+ layout::row_major);
42+ // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
43+ joint_matrix_load (sg, sub_a,
44+ accA.template get_multi_ptr <access::decorated::yes>(), K);
45+ // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
46+ joint_matrix_load (sg, sub_b,
47+ accB.template get_multi_ptr <access::decorated::yes>(), N);
48+ // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
49+ joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
50+ // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
51+ joint_matrix_store (sg, sub_c,
52+ accD.template get_multi_ptr <access::decorated::yes>(), N,
53+ layout::row_major);
54+ }
55+
56+ SYCL_EXTERNAL [[sycl::reqd_work_group_size(1 , 1 , 32 )]] void
57+ col_col_m8n8k4 (sycl::accessor<double , 1 , sycl::access::mode::read_write,
58+ sycl::target::device>
59+ accA,
60+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
61+ sycl::target::device>
62+ accB,
63+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
64+ sycl::target::device>
65+ accC,
66+ sycl::accessor<double , 1 , sycl::access::mode::read_write,
67+ sycl::target::device>
68+ accD,
69+ nd_item<2 > item) {
70+ sycl::sub_group sg = item.get_sub_group ();
71+
72+ joint_matrix<sub_group, double , use::accumulator, M, N> sub_c{};
73+ joint_matrix<sub_group, double , use::a, M, K, layout::col_major> sub_a{};
74+ joint_matrix<sub_group, double , use::b, K, N, layout::col_major> sub_b{};
75+
76+ // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
77+ joint_matrix_load (sg, sub_c,
78+ accC.template get_multi_ptr <access::decorated::yes>(), M,
79+ layout::col_major);
80+ // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8)
81+ joint_matrix_load (sg, sub_a,
82+ accA.template get_multi_ptr <access::decorated::yes>(), M);
83+ // CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4)
84+ joint_matrix_load (sg, sub_b,
85+ accB.template get_multi_ptr <access::decorated::yes>(), K);
86+ // CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}})
87+ joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
88+ // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8)
89+ joint_matrix_store (sg, sub_c,
90+ accD.template get_multi_ptr <access::decorated::yes>(), M,
91+ layout::col_major);
92+ }
0 commit comments