@@ -173,6 +173,9 @@ namespace ggml_cuda_mma {
 #elif defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         static constexpr int ne = I * J / 32;
+#elif defined(RDNA3)
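+        // gfx11 WMMA: 16x16 accumulator tiles keep I*J/32 values per lane, while operand tiles are
+        // replicated across half-waves, so each lane holds twice as many elements (I*J/16)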
+        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
+#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
@@ -182,7 +185,11 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 16 && J == 16) {
+#if defined(RDNA4)
                 return 8 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
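+                // gfx11 accumulator layout: each lane holds every other row of the 16x16 tile,
+                // with the half-wave index (threadIdx.x / 16) selecting even vs. odd rows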
+                return 2 * l + (threadIdx.x / 16);
+#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -197,7 +204,6 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
-#endif
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -284,6 +290,7 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
+
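+        // half2 packs two f16 values per element, so each of the 32 lanes stores I*J/32 elements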
         static constexpr int ne = I * J / 32;
         half2 x[ne] = {{0.0f, 0.0f}};
 
@@ -544,23 +551,40 @@ namespace ggml_cuda_mma {
         } else if constexpr (std::is_same_v<T, int>) {
             if constexpr (I == 16 && J == 4) {
                 int64_t * xi = (int64_t *) t.x;
+#if defined(RDNA4)
                 const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
                 xi[0] = xs[0];
-
+#elif defined(RDNA3)
+                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
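+                // each lane loads its full 16-byte row of the 16x4 int tile as two contiguous 64-bit chunks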
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
+                xi[0] = xs[0];
+                xi[1] = xs[1];
+#endif // defined(RDNA4)
             } else if constexpr (I == 16 && J == 8) {
                 int64_t * xi = (int64_t *) t.x;
+#if defined(RDNA4)
                 const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
                 xi[0] = xs[0];
 
                 const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
                 xi[1] = xs1[0];
+#elif defined(RDNA3)
+                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
+                // four contiguous 64-bit chunks per lane for the wider RDNA3 fragment
+                xi[0] = xs[0];
+                xi[1] = xs[1];
+                const int64_t * xs1 = xs + 2;
+                xi[2] = xs1[0];
+                xi[3] = xs1[1];
+#endif // defined(RDNA4)
 
             } else {
                 NO_DEVICE_CODE;
             }
         } else {
             NO_DEVICE_CODE;
         }
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -858,12 +882,14 @@ namespace ggml_cuda_mma {
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t & acc_frag = reinterpret_cast<floatx8_t &>(D.x[0]);
         const halfx8_t & a_frag = reinterpret_cast<const halfx8_t &>(A.x[0]);
         const halfx8_t & b_frag = reinterpret_cast<const halfx8_t &>(B.x[0]);
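+        // gfx12 WMMA: each lane supplies 8 halves per operand and accumulates into 8 floats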
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -873,12 +899,14 @@ namespace ggml_cuda_mma {
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t & acc_frag = reinterpret_cast<floatx8_t &>(D.x[0]);
         const bf16x8_t & a_frag = reinterpret_cast<const bf16x8_t &>(A.x[0]);
         const bf16x8_t & b_frag = reinterpret_cast<const bf16x8_t &>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -907,14 +935,14 @@ namespace ggml_cuda_mma {
 #endif // defined(CDNA3)
 
 #elif defined(AMD_WMMA_AVAILABLE)
-        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
-        int32x2_t * a_vec = (int32x2_t *) A.x;
-        int32x2_t * b_vec = (int32x2_t *) B.x;
 
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
 
 #if defined(RDNA4)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
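+        // gfx12 packs 8 int8 values (two dwords) per operand per 16x16x16 step;
+        // the two chained calls below cover K = 0..15 and K = 16..31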
 
         acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
             true,
@@ -933,7 +961,30 @@ namespace ggml_cuda_mma {
             acc[0],
             true
         );
-#endif // defined(RDNA4)
+
+#elif defined(RDNA3)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * a_vec = (int32x4_t *) A.x;
+        int32x4_t * b_vec = (int32x4_t *) B.x;
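+        // gfx11 replicates operands across half-waves: each 16x16x16 step consumes four packed
+        // dwords (16 int8 values) per operand, so two chained calls cover the K = 32 tile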
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // RDNA4
 
 #else
         GGML_UNUSED_VARS(D, A, B);
@@ -1020,27 +1071,40 @@ namespace ggml_cuda_mma {
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
-        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
-        int32x2_t * a_vec = (int32x2_t *) A.x;
-        int32x2_t * b_vec = (int32x2_t *) B.x;
-
-        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
-        int32x8_t * acc = (int32x8_t *) D.x;
-
-        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
-            true,
-            a_vec[0],
-            true,
-            b_vec[0],
-            acc[0],
-            false
-        );
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+#if defined(RDNA4)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            false
+        );
+#elif defined(RDNA3)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * a_vec = (int32x4_t *) A.x;
+        int32x4_t * b_vec = (int32x4_t *) B.x;
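+        // a single 16x16x16 step suffices on gfx11: each lane already holds the full 16-int8 row
+        // of the 16x4 int tile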
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            false
+        );
+#endif // RDNA4
 #else
         GGML_UNUSED(D);
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif
+#endif // AMD_WMMA_AVAILABLE
     }
 }
-