 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"

-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif

-#if defined(__HIPCC__) && defined(__gfx942__)
-  #define __HIP__MI300__
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__MI3XX__
 #endif

+#if defined(__gfx950__)
+  #define LDS_SIZE (160 * 1024)
+#else
+  #define LDS_SIZE (64 * 1024)
+#endif
+
+static int get_lds_size() {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx95");
+  if (substring == std::string::npos) return 64 * 1024;
+  return 160 * 1024;
+}
+
 #if defined(NDEBUG)
   #undef NDEBUG
   #include <assert.h>
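A note on why the patch introduces both a macro and a runtime helper: `__shared__` arrays must be sized by a compile-time constant, and `__gfx950__` is only defined while device code is being compiled for that target, so the kernels use `LDS_SIZE`; the host side of the same binary may run against any installed GPU, so it inspects `gcnArchName` at runtime instead. A minimal sketch of the same pattern, assuming plain HIP rather than the file's ATen helpers (kernel and helper names here are illustrative, not from the patch):

```cpp
// Sketch, not part of the patch: the compile-time/runtime split in miniature.
// Device code is compiled once per gfx target, so an arch macro can size a
// static __shared__ array; host code runs on whatever GPU is present, so it
// must ask the runtime.
#include <hip/hip_runtime.h>
#include <cstring>

#if defined(__gfx950__)
  #define SKETCH_LDS_BYTES (160 * 1024)
#else
  #define SKETCH_LDS_BYTES (64 * 1024)
#endif

__global__ void lds_demo_kernel(float* out) {
  // The bound must be a compile-time constant: bytes / sizeof(element).
  __shared__ float s[SKETCH_LDS_BYTES / sizeof(float)];
  s[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = s[threadIdx.x];
}

// Host-side mirror of get_lds_size(), using plain HIP properties; gcnArchName
// looks like "gfx942:sramecc+:xnack-", and gfx95x parts expose 160 KB of LDS.
static int lds_bytes_for_device(int device_id) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, device_id) != hipSuccess) return 64 * 1024;
  return std::strstr(props.gcnArchName, "gfx95") ? 160 * 1024 : 64 * 1024;
}
```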
@@ -267,15 +282,16 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
     V0 += (s.x + s.y); \
   }

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
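The divisor on `LDS_SIZE` tracks the element width: `LDS_SIZE` is in bytes, while the `s[]` buffer below is declared in `scalar_t` elements, and the half/bfloat16 kernels use 2-byte elements, so dividing by 2 reproduces the old `1024 * 32` element bound on 64 KB parts. The fp8 kernels near the end of the file keep `LDS_SIZE` undivided because their elements are a single byte. A hedged sketch making the conversion explicit (the helper name is invented for illustration):

```cpp
// Sketch: byte capacity -> element capacity. With 2-byte elements and the
// 64 KB case this recovers the old hardcoded 32 * 1024 bound.
template <typename T>
constexpr int lds_elems(int lds_bytes) {
  return lds_bytes / static_cast<int>(sizeof(T));
}
static_assert(lds_elems<unsigned short>(64 * 1024) == 32 * 1024, "bf16/half");
static_assert(lds_elems<unsigned char>(64 * 1024) == 64 * 1024, "fp8");
```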
@@ -295,13 +311,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };

   // ----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   // ----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   // ----------------------------------------------------
   // Fetch the activation matrix to LDS
@@ -312,11 +328,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //  - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   // ----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
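The loop above is a workgroup-cooperative fill: each of the `THRDS * WvPrGrp` threads copies `A_CHUNK` contiguous elements per trip through one vectorized `bigType` move, so a full trip advances `THRDS * WvPrGrp * A_CHUNK` elements of the flattened A, which is also why the TODO requires K to be a multiple of 8. A stripped-down sketch of the same access pattern (scalar stores instead of the vector type, names illustrative):

```cpp
// Sketch of the cooperative LDS fill used above: thread (y, x) owns the
// A_CHUNK-element slice starting at (threadIdx.y * THRDS + threadIdx.x) *
// A_CHUNK, and the whole workgroup advances THRDS * WvPrGrp * A_CHUNK
// elements per trip. The real loop moves each slice with one bigType vector
// copy; plain scalar stores keep this sketch self-contained.
template <typename T, int THRDS, int WvPrGrp, int A_CHUNK>
__device__ void fill_lds(T* lds, const T* __restrict__ src, uint32_t total) {
  const uint32_t tid = threadIdx.y * THRDS + threadIdx.x;
  for (uint32_t k = tid * A_CHUNK; k < total; k += THRDS * WvPrGrp * A_CHUNK) {
    for (int i = 0; i < A_CHUNK; ++i) lds[k + i] = src[k + i];
  }
  __syncthreads();  // every thread falls out of the loop and syncs here
}
```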
@@ -517,25 +533,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -561,7 +578,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   // ----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   // ----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -598,11 +615,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //  - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   // ----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -686,7 +703,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Fetch A activation matrix in interleaved fashion from LDS or memory

       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 32 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
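This conditional is what lets `wvSplitK_hf_` tolerate activations that only marginally exceed LDS: flattened indices below `max_lds_len` are served from the shared copy staged earlier, and the tail falls through to global memory. The rule in isolation, as a hedged sketch (the helper is mine, not the patch's):

```cpp
// Sketch of the spill rule above: the first max_lds_len elements of the
// flattened A were staged into LDS, so anything past that is read straight
// from global memory. This pairs with the <= 1.2x dispatch threshold below.
template <typename T>
__device__ inline T load_a(const T* a_global, const T* a_lds, uint32_t idx,
                           uint32_t lds_len) {
  return (idx < lds_len) ? a_lds[idx] : a_global[idx];
}
```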
@@ -817,25 +834,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   }
 }

-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                              const scalar_t* __restrict__ A, scalar_t* C,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -855,13 +873,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };

   // ----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   // ----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   // ----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -902,11 +920,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // ----------------------------------------------------
 #define PCML
 #ifndef PCML
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -916,7 +934,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #define TUC (THRDS * UNRL * A_CHUNK)
   uint32_t kBase = 0;
   // find biggest k size that fits in LDS
-  uint32_t kFit = (32 * 1024) / N;
+  uint32_t kFit = (max_lds_len) / N;
   // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
   // of TUC
   kFit = (kFit % TUC == 0)
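`kFit` now scales with the architecture: take the per-type LDS element capacity, divide by the number of activation rows N, then round up to a multiple of `TUC = THRDS * UNRL * A_CHUNK` so chunk boundaries stay aligned with the unrolled copy. A worked instance under assumed parameters (gfx950, 2-byte elements, N = 3, TUC = 1024):

```cpp
// Sketch: the round-up applied to kFit. 160 KB / 2-byte elements = 81920;
// 81920 / 3 = 27306, which rounds up to the next multiple of 1024 = 27648.
constexpr uint32_t round_up(uint32_t x, uint32_t m) {
  return (x % m == 0) ? x : (x - x % m + m);
}
static_assert(round_up((160 * 1024 / 2) / 3, 1024) == 27648, "worked example");
```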
@@ -1164,15 +1182,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
@@ -1222,17 +1240,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,

   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size() / 2;

 #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N)                                                          \
   {                                                                           \
     dim3 block(64, _WvPrGrp);                                                 \
-    if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) {                \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {              \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);              \
       wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N>          \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,    \
                                        CuCount);                             \
-    } else if (K_in * N_in <= 32 * 1024 * 1.2) {                             \
+    } else if (K_in * N_in <= max_lds_len * 1.2) {                           \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp);             \
       wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N>             \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,   \
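For orientation, the tiering that `WVSPLITK` performs, reduced to plain control flow; the big-A branch is not visible in this hunk but follows the same shape with `wvSplitK_hf_big_`. The enum and function below are illustrative only, not part of the patch:

```cpp
// Sketch of the kernel tiering in wvSplitK. max_lds_len is now the runtime
// per-arch element capacity (get_lds_size() / 2) instead of the old 32 * 1024.
enum class Variant { Small, Medium, Big };

inline Variant pick_variant(int K, int N, int M, int ytile_s, int max_lds_len) {
  if (K * N <= max_lds_len && M % ytile_s == 0)
    return Variant::Small;   // A fits entirely in LDS
  if (K * N <= max_lds_len * 1.2)
    return Variant::Medium;  // A marginally exceeds LDS; tail read from global
  return Variant::Big;       // A streamed through LDS in kFit-sized chunks
}
```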
@@ -1272,7 +1291,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }

-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1300,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1296,10 +1316,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };

-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];

   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1436,7 +1456,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,16 +1466,17 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
                   const fp8_t* __restrict__ A, scalar_t* C,
                   const float* __restrict__ s_A, const float* __restrict__ s_B,
                   const int _WvPrGrp, const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1492,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };

-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];

   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1517,7 +1538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       uint32_t k_ = k + threadIdx.x * A_CHUNK;
       if (k_ >= K) break;
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 64 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1629,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1639,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                               const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1659,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
   dim3 grid(CuCount);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size();

 #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                   _N)                                                          \
   {                                                                            \
     dim3 block(64, _WvPrGrp);                                                  \
-    if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) {                 \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {               \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);               \
       wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N>  \
           <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \