@@ -219,10 +219,11 @@ struct joint_matrix_load_impl<
219219 void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
220220 S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
221221 multi_ptr<T, Space> src, size_t stride) {
222- if constexpr (std::is_same<T , uint16_t >::value ||
222+ if constexpr (std::is_same<std:: remove_const_t <T> , uint16_t >::value ||
223223 std::is_same<
224- T, sycl::ext::oneapi::experimental::bfloat16>::value) {
225- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
224+ std::remove_const_t <T>,
225+ sycl::ext::oneapi::experimental::bfloat16>::value) {
226+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
226227 auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
227228 if constexpr (NumRows == 16 && NumCols == 16 ) {
228229 if constexpr (Use ==
@@ -247,8 +248,8 @@ struct joint_matrix_load_impl<
247248 __mma_bf16_m32n8k16_ld_b (destptr, tileptr, stride,
248249 get_layout_id<Layout>());
249250 }
250- } else if constexpr (std::is_same<T , uint8_t >::value) {
251- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
251+ } else if constexpr (std::is_same<std:: remove_const_t <T> , uint8_t >::value) {
252+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
252253 auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
253254 if constexpr (NumRows == 16 && NumCols == 16 ) {
254255 if constexpr (Use ==
@@ -273,8 +274,8 @@ struct joint_matrix_load_impl<
273274 __imma_m32n8k16_ld_b_u8 (destptr, tileptr, stride,
274275 get_layout_id<Layout>());
275276 }
276- } else if constexpr (std::is_same<T , int8_t >::value) {
277- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
277+ } else if constexpr (std::is_same<std:: remove_const_t <T> , int8_t >::value) {
278+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
278279 auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
279280 if constexpr (NumRows == 16 && NumCols == 16 ) {
280281 if constexpr (Use ==
@@ -299,8 +300,8 @@ struct joint_matrix_load_impl<
299300 __imma_m32n8k16_ld_b_s8 (destptr, tileptr, stride,
300301 get_layout_id<Layout>());
301302 }
302- } else if constexpr (std::is_same<T , half>::value) {
303- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
303+ } else if constexpr (std::is_same<std:: remove_const_t <T> , half>::value) {
304+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
304305 auto dstptr = reinterpret_cast <int32_t *>(&res.wi_marray );
305306 if constexpr (NumRows == 16 && NumCols == 16 ) {
306307 if constexpr (Use ==
@@ -332,7 +333,7 @@ struct joint_matrix_load_impl<
332333 get_layout_id<Layout>());
333334 }
334335
335- } else if constexpr (std::is_same<T , int32_t >::value) {
336+ } else if constexpr (std::is_same<std:: remove_const_t <T> , int32_t >::value) {
336337 auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
337338 if constexpr (NumRows == 16 && NumCols == 16 ) {
338339 __imma_m16n16k16_ld_c (destptr, src.get (), stride,
@@ -344,7 +345,7 @@ struct joint_matrix_load_impl<
344345 __imma_m32n8k16_ld_c (destptr, src.get (), stride,
345346 get_layout_id<Layout>());
346347 }
347- } else if constexpr (std::is_same<T , float >::value) {
348+ } else if constexpr (std::is_same<std:: remove_const_t <T> , float >::value) {
348349 if constexpr (std::is_same<S, float >::value) {
349350 auto dstptr = reinterpret_cast <float *>(&res.wi_marray );
350351 if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -360,7 +361,7 @@ struct joint_matrix_load_impl<
360361 } else if constexpr (std::is_same<S,
361362 sycl::ext::oneapi::experimental::
362363 matrix::precision::tf32>::value) {
363- auto tileptr = reinterpret_cast <int32_t *>(src.get ());
364+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
364365 auto dstptr = reinterpret_cast <int32_t *>(&res.wi_marray );
365366 if constexpr (NumRows == 16 && NumCols == 8 ) {
366367 __mma_tf32_m16n16k8_ld_a (dstptr, tileptr, stride,
@@ -370,7 +371,7 @@ struct joint_matrix_load_impl<
370371 get_layout_id<Layout>());
371372 }
372373 }
373- } else if constexpr (std::is_same<T , double >::value) {
374+ } else if constexpr (std::is_same<std:: remove_const_t <T> , double >::value) {
374375 auto dstptr = reinterpret_cast <double *>(&res.wi_marray );
375376 if constexpr (Use ==
376377 sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
@@ -560,9 +561,9 @@ struct joint_matrix_mad_impl<
560561 D;
561562 if constexpr (M == 16 && N == 16 && K == 16 ) {
562563 if constexpr (std::is_same<T2, int32_t >::value) {
563- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
564- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
565- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
564+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
565+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
566+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
566567 auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
567568 if constexpr (std::is_same<T1, int8_t >::value) {
568569 __imma_m16n16k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -572,34 +573,34 @@ struct joint_matrix_mad_impl<
572573 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
573574 }
574575 } else if constexpr (std::is_same<T1, half>::value) {
575- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
576- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
576+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
577+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
577578 if constexpr (std::is_same<T2, float >::value) {
578579 __hmma_m16n16k16_mma_f32f32 (
579580 reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
580- reinterpret_cast <float const *>(&C.wi_marray ),
581+ reinterpret_cast <const float *>(&C.wi_marray ),
581582 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
582583 } else if constexpr (std::is_same<T2, half>::value) {
583584 __hmma_m16n16k16_mma_f16f16 (
584585 reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
585- reinterpret_cast <int32_t const *>(&C.wi_marray ),
586+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
586587 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
587588 }
588589 } else if constexpr (std::is_same<T1, uint16_t >::value ||
589590 std::is_same<T1, sycl::ext::oneapi::experimental::
590591 bfloat16>::value) {
591592 __mma_bf16_m16n16k16_mma_f32 (
592593 reinterpret_cast <float *>(&D.wi_marray ),
593- reinterpret_cast <int32_t const *>(&A.wi_marray ),
594- reinterpret_cast <int32_t const *>(&B.wi_marray ),
595- reinterpret_cast <float const *>(&C.wi_marray ),
594+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
595+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
596+ reinterpret_cast <const float *>(&C.wi_marray ),
596597 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
597598 }
598599 } else if constexpr (M == 8 && N == 32 && K == 16 ) {
599600 if constexpr (std::is_same<T2, int32_t >::value) {
600- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
601- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
602- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
601+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
602+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
603+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
603604 auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
604605 if constexpr (std::is_same<T1, int8_t >::value) {
605606 __imma_m8n32k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -609,34 +610,34 @@ struct joint_matrix_mad_impl<
609610 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
610611 }
611612 } else if constexpr (std::is_same<T1, half>::value) {
612- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
613- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
613+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
614+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
614615 if constexpr (std::is_same<T2, float >::value) {
615616 __hmma_m8n32k16_mma_f32f32 (
616617 reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
617- reinterpret_cast <float const *>(&C.wi_marray ),
618+ reinterpret_cast <const float *>(&C.wi_marray ),
618619 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
619620 } else if constexpr (std::is_same<T2, half>::value) {
620621 __hmma_m8n32k16_mma_f16f16 (
621622 reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
622- reinterpret_cast <int32_t const *>(&C.wi_marray ),
623+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
623624 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
624625 }
625626 } else if constexpr (std::is_same<T1, uint16_t >::value ||
626627 std::is_same<T1, sycl::ext::oneapi::experimental::
627628 bfloat16>::value) {
628629 __mma_bf16_m8n32k16_mma_f32 (
629630 reinterpret_cast <float *>(&D.wi_marray ),
630- reinterpret_cast <int32_t const *>(&A.wi_marray ),
631- reinterpret_cast <int32_t const *>(&B.wi_marray ),
632- reinterpret_cast <float const *>(&C.wi_marray ),
631+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
632+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
633+ reinterpret_cast <const float *>(&C.wi_marray ),
633634 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
634635 }
635636 } else if constexpr (M == 32 && N == 8 && K == 16 ) {
636637 if constexpr (std::is_same<T2, int32_t >::value) {
637- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
638- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
639- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
638+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
639+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
640+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
640641 auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
641642 if constexpr (std::is_same<T1, int8_t >::value) {
642643 __imma_m32n8k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -650,22 +651,22 @@ struct joint_matrix_mad_impl<
650651 bfloat16>::value) {
651652 __mma_bf16_m32n8k16_mma_f32 (
652653 reinterpret_cast <float *>(&D.wi_marray ),
653- reinterpret_cast <int32_t const *>(&A.wi_marray ),
654- reinterpret_cast <int32_t const *>(&B.wi_marray ),
655- reinterpret_cast <float const *>(&C.wi_marray ),
654+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
655+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
656+ reinterpret_cast <const float *>(&C.wi_marray ),
656657 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
657658 } else if constexpr (std::is_same<T1, half>::value) {
658- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
659- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
659+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
660+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
660661 if constexpr (std::is_same<T2, float >::value) {
661662 __hmma_m32n8k16_mma_f32f32 (
662663 reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
663- reinterpret_cast <float const *>(&C.wi_marray ),
664+ reinterpret_cast <const float *>(&C.wi_marray ),
664665 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
665666 } else if constexpr (std::is_same<T2, half>::value) {
666667 __hmma_m32n8k16_mma_f16f16 (
667668 reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
668- reinterpret_cast <int32_t const *>(&C.wi_marray ),
669+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
669670 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
670671 }
671672 }
@@ -677,9 +678,9 @@ struct joint_matrix_mad_impl<
677678 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
678679 } else if constexpr (std::is_same<T1, double >::value) {
679680 __dmma_m8n8k4_mma_f64 (reinterpret_cast <double *>(&D.wi_marray ),
680- reinterpret_cast <double const *>(&A.wi_marray ),
681- reinterpret_cast <double const *>(&B.wi_marray ),
682- reinterpret_cast <double const *>(&C.wi_marray ),
681+ reinterpret_cast <const double *>(&A.wi_marray ),
682+ reinterpret_cast <const double *>(&B.wi_marray ),
683+ reinterpret_cast <const double *>(&C.wi_marray ),
683684 get_layout_pair_id<LayoutA, LayoutB>(), 0 );
684685 }
685686 return D;
@@ -692,13 +693,14 @@ struct joint_matrix_mad_impl<
692693namespace experimental {
693694namespace matrix {
694695
695- template <typename Group, typename S, typename T, matrix_use Use,
696- size_t NumRows, size_t NumCols, matrix_layout Layout,
697- access::address_space Space,
698- std::enable_if_t <std::is_same<S, T>::value ||
699- (std::is_same<S, precision::tf32>::value &&
700- std::is_same<T, float >::value),
701- bool > = true >
696+ template <
697+ typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
698+ size_t NumCols, matrix_layout Layout, access::address_space Space,
699+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value ||
700+ (std::is_same<S, precision::tf32>::value &&
701+
702+ std::is_same<std::remove_const_t <T>, float >::value),
703+ bool > = true >
702704void joint_matrix_load (
703705 Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
704706 multi_ptr<T, Space> src, size_t stride) {
0 commit comments