@@ -51,34 +51,45 @@ struct joint_matrix<
5151} // namespace experimental::matrix
5252
5353namespace detail {
54- using namespace experimental ;
5554
56- template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
57- matrix::matrix_layout Layout, access::address_space Space,
58- typename Cond = void >
55+ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use MT,
56+ size_t NumRows, size_t NumCols,
57+ sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
58+ access::address_space Space, typename Cond = void >
5959struct joint_matrix_load_impl {
60- void load (matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
60+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
61+ T, MT, NumRows, NumCols, Layout> &res,
6162 multi_ptr<T, Space> src, size_t stride);
6263};
6364
64- template <matrix::matrix_layout Layout> constexpr int get_layout_id ();
65+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout>
66+ constexpr int get_layout_id ();
6567
66- template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
68+ template <>
69+ constexpr int get_layout_id<
70+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
6771 return 0 ;
6872}
6973
70- template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
74+ template <>
75+ constexpr int get_layout_id<
76+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
7177 return 1 ;
7278}
7379
74- template <matrix::matrix_layout Layout, access::address_space Space>
80+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
81+ access::address_space Space>
7582struct joint_matrix_load_impl <
76- double , matrix::matrix_use::a, 8 , 4 , Layout, Space,
77- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
78- Layout == matrix::matrix_layout::col_major>> {
79- void
80- load (matrix::joint_matrix<double , matrix::matrix_use::a, 8 , 4 , Layout> &res,
81- multi_ptr<double , Space> src, size_t stride) {
83+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8 , 4 ,
84+ Layout, Space,
85+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
86+ matrix::matrix_layout::row_major ||
87+ Layout == sycl::ext::oneapi::experimental::
88+ matrix::matrix_layout::col_major>> {
89+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
90+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a,
91+ 8 , 4 , Layout> &res,
92+ multi_ptr<double , Space> src, size_t stride) {
8293
8394#ifdef __NVPTX__
8495#ifdef __SYCL_DEVICE_ONLY__
@@ -88,14 +99,19 @@ struct joint_matrix_load_impl<
8899 }
89100};
90101
91- template <matrix::matrix_layout Layout, access::address_space Space>
102+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
103+ access::address_space Space>
92104struct joint_matrix_load_impl <
93- double , matrix::matrix_use::b, 4 , 8 , Layout, Space,
94- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
95- Layout == matrix::matrix_layout::col_major>> {
96- void
97- load (matrix::joint_matrix<double , matrix::matrix_use::b, 4 , 8 , Layout> &res,
98- multi_ptr<double , Space> src, size_t stride) {
105+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4 , 8 ,
106+ Layout, Space,
107+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
108+ matrix::matrix_layout::row_major ||
109+ Layout == sycl::ext::oneapi::experimental::
110+ matrix::matrix_layout::col_major>> {
111+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
112+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b,
113+ 4 , 8 , Layout> &res,
114+ multi_ptr<double , Space> src, size_t stride) {
99115#ifdef __NVPTX__
100116#ifdef __SYCL_DEVICE_ONLY__
101117 __dmma_m8n8k4_ld_b (res.data , src.get (), stride, get_layout_id<Layout>());
@@ -104,14 +120,21 @@ struct joint_matrix_load_impl<
104120 }
105121};
106122
107- template <matrix::matrix_layout Layout, access::address_space Space>
123+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
124+ access::address_space Space>
108125struct joint_matrix_load_impl <
109- double , matrix::matrix_use::accumulator, 8 , 8 , Layout, Space,
110- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
111- Layout == matrix::matrix_layout::col_major>> {
112- void load (matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
113- Layout> &res,
114- multi_ptr<double , Space> src, size_t stride) {
126+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
127+ 8 , Layout, Space,
128+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
129+ matrix::matrix_layout::row_major ||
130+ Layout == sycl::ext::oneapi::experimental::
131+ matrix::matrix_layout::col_major>> {
132+ void
133+ load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
134+ double ,
135+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
136+ 8 , Layout> &res,
137+ multi_ptr<double , Space> src, size_t stride) {
115138
116139#ifdef __NVPTX__
117140#ifdef __SYCL_DEVICE_ONLY__
@@ -122,22 +145,30 @@ struct joint_matrix_load_impl<
122145};
123146
124147template <typename T, size_t NumRows, size_t NumCols,
125- matrix::matrix_layout Layout, access::address_space Space ,
126- typename Cond = void >
148+ sycl::ext::oneapi::experimental:: matrix::matrix_layout Layout,
149+ access::address_space Space, typename Cond = void >
127150struct joint_matrix_store_impl {
128- void store (matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
129- NumCols, Layout> &src,
130- multi_ptr<T, Space> dst, size_t stride);
151+ void
152+ store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
153+ T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
154+ NumRows, NumCols, Layout> &src,
155+ multi_ptr<T, Space> dst, size_t stride);
131156};
132157
133- template <matrix::matrix_layout Layout, access::address_space Space>
158+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
159+ access::address_space Space>
134160struct joint_matrix_store_impl <
135161 double , 8 , 8 , Layout, Space,
136- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
137- Layout == matrix::matrix_layout::col_major>> {
138- void store (matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
139- Layout> &src,
140- multi_ptr<double , Space> dst, size_t stride) {
162+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
163+ matrix::matrix_layout::row_major ||
164+ Layout == sycl::ext::oneapi::experimental::
165+ matrix::matrix_layout::col_major>> {
166+ void
167+ store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
168+ double ,
169+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
170+ 8 , Layout> &src,
171+ multi_ptr<double , Space> dst, size_t stride) {
141172
142173#ifdef __NVPTX__
143174#ifdef __SYCL_DEVICE_ONLY__
@@ -149,60 +180,98 @@ struct joint_matrix_store_impl<
149180};
150181
151182template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
152- matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
153- matrix::matrix_layout LayoutC, typename Cond = void >
183+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
184+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
185+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
186+ typename Cond = void >
154187struct joint_matrix_mad_impl {
155- matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
156- mad (matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
157- matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
158- matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
188+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
189+ T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
190+ N, LayoutC>
191+ mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
192+ T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
193+ LayoutA>
194+ A,
195+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
196+ T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
197+ LayoutB>
198+ B,
199+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
200+ T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
201+ M, N, LayoutC>
159202 C);
160203};
161204
162- template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
205+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
206+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
163207constexpr int get_layout_pair_id ();
164208
165209template <>
166- constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
167- matrix::matrix_layout::row_major>() {
210+ constexpr int get_layout_pair_id<
211+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
212+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
168213 return 0 ;
169214}
170215
171216template <>
172- constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
173- matrix::matrix_layout::col_major>() {
217+ constexpr int get_layout_pair_id<
218+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
219+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
174220 return 1 ;
175221}
176222
177223template <>
178- constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
179- matrix::matrix_layout::row_major>() {
224+ constexpr int get_layout_pair_id<
225+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
226+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
180227 return 2 ;
181228}
182229
183230template <>
184- constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
185- matrix::matrix_layout::col_major>() {
231+ constexpr int get_layout_pair_id<
232+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
233+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
186234 return 3 ;
187235}
188236
189- template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
190- matrix::matrix_layout LayoutC>
237+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
238+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
239+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
191240struct joint_matrix_mad_impl <
192241 double , double , 8 , 4 , 8 , LayoutA, LayoutB, LayoutC,
193- typename std::enable_if_t <(LayoutA == matrix::matrix_layout::row_major ||
194- LayoutA == matrix::matrix_layout::col_major) &&
195- (LayoutB == matrix::matrix_layout::row_major ||
196- LayoutB == matrix::matrix_layout::col_major) &&
197- (LayoutC == matrix::matrix_layout::row_major ||
198- LayoutC == matrix::matrix_layout::col_major)>> {
199- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 , LayoutC>
200- mad (matrix::joint_matrix<double , matrix::matrix_use::a, 8 , 4 , LayoutA> A,
201- matrix::joint_matrix<double , matrix::matrix_use::b, 4 , 8 , LayoutB> B,
202- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
203- LayoutC>
242+ typename std::enable_if_t <
243+ (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
244+ row_major ||
245+ LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
246+ col_major) &&
247+ (LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
248+ row_major ||
249+ LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
250+ col_major) &&
251+ (LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
252+ row_major ||
253+ LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
254+ col_major)>> {
255+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
256+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
257+ 8 , 8 , LayoutC>
258+ mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
259+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8 , 4 ,
260+ LayoutA>
261+ A,
262+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
263+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4 , 8 ,
264+ LayoutB>
265+ B,
266+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
267+ double ,
268+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
269+ 8 , LayoutC>
204270 C) {
205- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 , LayoutC>
271+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
272+ double ,
273+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 , 8 ,
274+ LayoutC>
206275 D;
207276
208277#ifdef __NVPTX__
@@ -225,8 +294,9 @@ template <typename Group, typename T, matrix_use MT, size_t NumRows,
225294void joint_matrix_load (
226295 Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
227296 multi_ptr<T, Space> src, size_t stride) {
228- detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load (
229- res, src, stride);
297+ sycl::ext::oneapi::detail::joint_matrix_load_impl<T, MT, NumRows, NumCols,
298+ Layout, Space>{}
299+ .load (res, src, stride);
230300}
231301
232302template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -235,8 +305,9 @@ void joint_matrix_store(Group sg,
235305 joint_matrix<T, matrix_use::accumulator, NumRows,
236306 NumCols, Layout, Group> &src,
237307 multi_ptr<T, Space> dst, size_t stride) {
238- detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store (
239- src, dst, stride);
308+ sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
309+ Layout, Space>{}
310+ .store (src, dst, stride);
240311}
241312
242313template <typename Group, typename T1, typename T2, std::size_t M,
@@ -247,8 +318,8 @@ joint_matrix_mad(
247318 Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
248319 joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
249320 joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
250- return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
251- LayoutC>{}
321+ return sycl::ext::oneapi:: detail::joint_matrix_mad_impl<
322+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
252323 .mad (A, B, C);
253324}
254325
0 commit comments