#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"

- #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-   #define __HIP__MI300_MI250__
+ #if defined(__HIPCC__) && \
+     (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+   #define __HIP__GFX9__
#endif

- #if defined(__HIPCC__) && defined(__gfx942__)
-   #define __HIP__MI300__
+ #if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+   #define __HIP__MI3XX__
#endif

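+ // gfx950 exposes 160 KB of LDS per workgroup; the older GFX9 targets
+ // (gfx90a/gfx942) expose 64 KB.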
+ #if defined(__gfx950__)
+   #define LDS_SIZE (160 * 1024)
+ #else
+   #define LDS_SIZE (64 * 1024)
+ #endif
+
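+ // Host-side counterpart of LDS_SIZE: queries gcnArchName once at runtime
+ // and caches the per-workgroup LDS byte capacity of the current device.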
+ int get_lds_size() {
+   static bool is_cached = false;
+   static int result;
+   if (is_cached == false) {
+     auto dprops = at::cuda::getCurrentDeviceProperties();
+     std::string device_arch = dprops->gcnArchName;
+     size_t substring = device_arch.find("gfx95");
+     result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
+     is_cached = true;
+   }
+   return result;
+ }
+
#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
@@ -267,15 +287,16 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
      V0 += (s.x + s.y); \
  }

- #if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+ #if defined(__HIP__GFX9__)  // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                     const scalar_t* __restrict__ A, scalar_t* C,
                     const int _WvPrGrp, const int CuCount) {
- #if defined(__HIP__MI300__)
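+   // LDS capacity in scalar_t elements: fp16/bf16 occupy 2 bytes each,
+   // hence the division by 2.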
+   constexpr int max_lds_len = LDS_SIZE / 2;
+ #if defined(__HIP__MI3XX__)
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
  constexpr bool use_mfma = false;
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  };

  // ----------------------------------------------------
-   // Reserving 64 KB of LDS to have 1 WG / CU
+   // Reserving 64/160 KB of LDS to have 1 WG / CU
  // Goal is to bring the activation matrix A to the LDS
  // and use it across the lifetime of the work group
  // TODO: When activation matrix is larger than 64 KB
  // then this is not going to work!
  // ----------------------------------------------------
-   __shared__ scalar_t s[1024 * 32];
+   __shared__ scalar_t s[max_lds_len];

  // ----------------------------------------------------
  // Fetch the activation matrix to LDS
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  // - Then the WG will move to another 8 K elements
  // TODO: Logic below will only work when K is multiple of 8
  // ----------------------------------------------------
-   for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+   for (uint32_t k = 0; k < min(K * N, max_lds_len);
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-     if (k_in >= min(K * N, 32 * 1024)) break;
+     if (k_in >= min(K * N, max_lds_len)) break;

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }
@@ -517,25 +538,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    m += CuCount * _WvPrGrp * YTILE;
  }
}
- #else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                 const scalar_t* __restrict__ A, scalar_t* C,
                                 const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
- #endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

- #if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+ #if defined(__HIP__GFX9__)  // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                 const scalar_t* __restrict__ A, scalar_t* C,
                 const int _WvPrGrp, const int CuCount) {
- #if defined(__HIP__MI300__)
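+   // LDS capacity in scalar_t elements: fp16/bf16 occupy 2 bytes each,
+   // hence the division by 2.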
+   constexpr int max_lds_len = LDS_SIZE / 2;
+ #if defined(__HIP__MI3XX__)
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
  constexpr bool use_mfma = false;
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  // TODO: When activation matrix is larger than 64 KB
  // then this is not going to work!
  // ----------------------------------------------------
-   __shared__ scalar_t s[1024 * 32];
+   __shared__ scalar_t s[max_lds_len];

  // ----------------------------------------------------
  // Computation of columns that need to be committed to memory!
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  // - Then the WG will move to another 8 K elements
  // TODO: Logic below will only work when K is multiple of 8
  // ----------------------------------------------------
-   for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+   for (uint32_t k = 0; k < min(K * N, max_lds_len);
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-     if (k_in >= min(K * N, 32 * 1024)) break;
+     if (k_in >= min(K * N, max_lds_len)) break;

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  // Fetch A activation matrix in interleaved fashion from LDS or memory

      for (int n = 0; n < N; n++) {
-         if (k_ + K * n < 32 * 1024)
+         if (k_ + K * n < max_lds_len)
          bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
        else
          bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -817,25 +839,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  }
}

- #else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                             const scalar_t* __restrict__ A, scalar_t* C,
                             const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
- #endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

- #if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+ #if defined(__HIP__GFX9__)  // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                     const scalar_t* __restrict__ A, scalar_t* C,
                     const int _WvPrGrp, const int CuCount) {
- #if defined(__HIP__MI300__)
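+   // LDS capacity in scalar_t elements: fp16/bf16 occupy 2 bytes each,
+   // hence the division by 2.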
+   constexpr int max_lds_len = LDS_SIZE / 2;
+ #if defined(__HIP__MI3XX__)
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
  constexpr bool use_mfma = false;
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  };

  // ----------------------------------------------------
-   // Reserving 64 KB of LDS to have 1 WG / CU
+   // Reserving 64/160 KB of LDS to have 1 WG / CU
  // Goal is to bring the activation matrix A to the LDS
  // and use it across the lifetime of the work group
  // TODO: When activation matrix is larger than 64 KB
  // then this is not going to work!
  // ----------------------------------------------------
-   __shared__ scalar_t s[1024 * 32];
+   __shared__ scalar_t s[max_lds_len];

  // ----------------------------------------------------
  // Computation of columns that need to be committed to memory!
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
  // ----------------------------------------------------
#define PCML
#ifndef PCML
-   for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+   for (uint32_t k = 0; k < min(K * N, max_lds_len);
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-     if (k_in >= min(K * N, 32 * 1024)) break;
+     if (k_in >= min(K * N, max_lds_len)) break;

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#define TUC (THRDS * UNRL * A_CHUNK)
  uint32_t kBase = 0;
  // find biggest k size that fits in LDS
-   uint32_t kFit = (32 * 1024) / N;
+   uint32_t kFit = (max_lds_len) / N;
  // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
  // of TUC
  kFit = (kFit % TUC == 0)
@@ -1164,15 +1187,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    }
  }
}
- #else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                 const scalar_t* __restrict__ A, scalar_t* C,
                                 const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
- #endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+ #endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

int mindiv(int N, int div1, int div2) {
  int nPrRnd = div1 * div2;
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,

  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
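+   // Scale the byte budget down to fp16/bf16 elements (2 bytes per element)
+   // to match the kernels' max_lds_len.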
+   const int max_lds_len = get_lds_size() / 2;

#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                 _N) \
  { \
    dim3 block(64, _WvPrGrp); \
-     if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
+     if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
      wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
          <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
                                       CuCount); \
-     } else if (K_in * N_in <= 32 * 1024 * 1.2) { \
+     } else if (K_in * N_in <= max_lds_len * 1.2) { \
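+       /* tolerate ~20% LDS overflow; wvSplitK_hf_ reads the excess of A from global memory */ \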
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
      wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
          <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
  return out_c;
}

- #if defined(__HIP__MI300__)  // TODO: Add NAVI support
+ #if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
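+   // fp8 elements are 1 byte, so the element capacity equals the full LDS
+   // byte size (no division).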
+   constexpr int max_lds_len = LDS_SIZE;
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
  using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    scalar8 h8;
  };

-   __shared__ fp8_t s[1024 * 64];
+   __shared__ fp8_t s[max_lds_len];

  for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-        k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+        k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
    *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
  }
  __syncthreads();
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    m += CuCount * _WvPrGrp * YTILE;
  }
}
- #else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+ #else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,16 +1471,17 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                  const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
- #endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+ #endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

- #if defined(__HIP__MI300__)  // TODO: Add NAVI support
+ #if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
                  const fp8_t* __restrict__ A, scalar_t* C,
                  const float* __restrict__ s_A, const float* __restrict__ s_B,
                  const int _WvPrGrp, const int CuCount) {
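+   // fp8 elements are 1 byte, so the element capacity equals the full LDS
+   // byte size (no division).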
+   constexpr int max_lds_len = LDS_SIZE;
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
  using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    scalar8 h8;
  };

-   __shared__ fp8_t s[1024 * 64];
+   __shared__ fp8_t s[max_lds_len];

  for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-        k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+        k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
    *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
  }
  __syncthreads();
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
      uint32_t k_ = k + threadIdx.x * A_CHUNK;
      if (k_ >= K) break;
      for (int n = 0; n < N; n++) {
-         if (k_ + K * n < 64 * 1024)
+         if (k_ + K * n < max_lds_len)
          bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
        else
          bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    m += CuCount * _WvPrGrp * YTILE;
  }
}
- #else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+ #else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                              const int CuCount) {
  UNREACHABLE_CODE
}
- #endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+ #endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
               at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
  dim3 grid(CuCount);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
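+   // fp8 path: one byte per element, so the device's byte capacity is used
+   // as-is.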
+   const int max_lds_len = get_lds_size();

#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N) \
  { \
    dim3 block(64, _WvPrGrp); \
-     if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
+     if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
      wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
          <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \