
Commit 628db8d

Merge remote-tracking branch 'origin/amd/gfx950_skinny_gemm' into upstream_merge_2025_05_29
2 parents: 421c498 + 0286875

File tree

5 files changed (+91, -54 lines)

csrc/rocm/skinny_gemms.cu

Lines changed: 70 additions & 43 deletions
@@ -13,14 +13,34 @@
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"
 
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-#define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+#define __HIP__GFX9__
 #endif
 
-#if defined(__HIPCC__) && defined(__gfx942__)
-#define __HIP__MI300__
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+#define __HIP__MI3XX__
 #endif
 
+#if defined(__gfx950__)
+#define LDS_SIZE 160 * 1024
+#else
+#define LDS_SIZE 64 * 1024
+#endif
+
+int get_lds_size() {
+  static bool is_cached = false;
+  static int result;
+  if (is_cached == false) {
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    std::string device_arch = dprops->gcnArchName;
+    size_t substring = device_arch.find("gfx95");
+    result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
+    is_cached = true;
+  }
+  return result;
+}
+
 #if defined(NDEBUG)
 #undef NDEBUG
 #include <assert.h>
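
Note on the new helpers: a minimal host-side sketch, assuming the compile-time LDS_SIZE used by the kernels and the runtime get_lds_size() used at dispatch are meant to agree (gfx95* parts exposing 160 KB of LDS per CU, the other supported gfx9 targets 64 KB). The helper lds_size_for_arch below is a hypothetical stand-in for the gcnArchName lookup; it is illustration, not part of the commit.

#include <cstdio>
#include <string>

// Same decision rule as get_lds_size(), minus the caching and the ATen query.
int lds_size_for_arch(const std::string& gcn_arch_name) {
  return gcn_arch_name.find("gfx95") == std::string::npos ? 64 * 1024
                                                          : 160 * 1024;
}

int main() {
  printf("%d\n", lds_size_for_arch("gfx942:sramecc+:xnack-"));  // 65536
  printf("%d\n", lds_size_for_arch("gfx950"));                  // 163840
  return 0;
}
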
@@ -267,15 +287,16 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
     V0 += (s.x + s.y); \
   }
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Fetch the activation matrix to LDS
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
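
Note on the fill loop above: an illustrative host-side simulation (not part of the commit) of the cooperative copy it performs. All THRDS * WvPrGrp lanes of the work group stride across the first min(K * N, max_lds_len) elements of A, A_CHUNK elements per lane per step; the launch shape and sizes below are example values chosen so the coverage check passes.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int THRDS = 64, WvPrGrp = 16, A_CHUNK = 8;     // example launch shape
  const int K = 2048, N = 8, max_lds_len = 80 * 1024;  // example sizes
  const int limit = std::min(K * N, max_lds_len);
  std::vector<int> writes(limit, 0);
  for (int y = 0; y < WvPrGrp; ++y)    // threadIdx.y
    for (int x = 0; x < THRDS; ++x)    // threadIdx.x
      for (int k = 0; k < limit; k += THRDS * WvPrGrp * A_CHUNK) {
        int k_in = k + (y * THRDS + x) * A_CHUNK;  // same index math as the kernel
        if (k_in >= limit) break;
        for (int c = 0; c < A_CHUNK; ++c) writes[k_in + c]++;
      }
  bool once = std::all_of(writes.begin(), writes.end(),
                          [](int w) { return w == 1; });
  printf("every element copied exactly once: %s\n", once ? "yes" : "no");
  return 0;
}
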
@@ -517,25 +538,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // TODO: When activation matrix is larger than 64 KB
   // then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Fetch A activation matrix in interleaved fashion from LDS or memory
 
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 32 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -817,25 +839,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   }
 }
 
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                              const scalar_t* __restrict__ A, scalar_t* C,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
 #define PCML
 #ifndef PCML
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #define TUC (THRDS * UNRL * A_CHUNK)
   uint32_t kBase = 0;
   // find biggest k size that fits in LDS
-  uint32_t kFit = (32 * 1024) / N;
+  uint32_t kFit = (max_lds_len) / N;
   // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
   // of TUC
   kFit = (kFit % TUC == 0)
@@ -1164,15 +1187,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size() / 2;
 
 #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N) \
   { \
     dim3 block(64, _WvPrGrp); \
-    if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
      wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
          <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
                                       CuCount); \
-    } else if (K_in * N_in <= 32 * 1024 * 1.2) { \
+    } else if (K_in * N_in <= max_lds_len * 1.2) { \
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
      wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
          <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
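
Note on the WVSPLITK threshold above: a sketch of how the dispatch is assumed to pick a kernel variant once the hard-coded 32 * 1024 is replaced by the runtime max_lds_len. The final else branch (the _big_ kernel) is not visible in this hunk and is assumed here; pick_variant is a hypothetical name used only for illustration.

#include <cstdio>

enum class Variant { Sml, Mid, Big };

Variant pick_variant(long k, long n, long m, long max_lds_len, long ytile_s) {
  if (k * n <= max_lds_len && m % ytile_s == 0) return Variant::Sml;  // A fits in LDS
  if (k * n <= max_lds_len * 1.2) return Variant::Mid;                // marginally exceeds LDS
  return Variant::Big;                                                // much larger than LDS
}

int main() {
  const long max_lds_len = 160 * 1024 / 2;  // gfx950 element budget for fp16/bf16
  printf("%d\n", static_cast<int>(pick_variant(4096, 8, 4096, max_lds_len, 2)));   // 0 = Sml
  printf("%d\n", static_cast<int>(pick_variant(16384, 8, 4096, max_lds_len, 2)));  // 2 = Big
  return 0;
}
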
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }
 
-#if defined(__HIP__MI300__) // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,16 +1471,17 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300__) // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
                   const fp8_t* __restrict__ A, scalar_t* C,
                   const float* __restrict__ s_A, const float* __restrict__ s_B,
                   const int _WvPrGrp, const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       uint32_t k_ = k + threadIdx.x * A_CHUNK;
       if (k_ >= K) break;
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 64 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                               const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
   dim3 grid(CuCount);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size();
 
 #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                   _N) \
   { \
     dim3 block(64, _WvPrGrp); \
-    if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
      wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
          <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
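
Note on the divisor difference: a small sketch (an assumption spelled out for illustration, not part of the commit) of why wvSplitK divides get_lds_size() by 2 while wvSplitKQ uses it directly. The __shared__ staging buffers are declared in elements, so the element budget is the LDS byte budget divided by sizeof(element): 2 bytes for half/bf16, 1 byte for fp8.

#include <cstdio>

int main() {
  const int lds_bytes = 160 * 1024;      // gfx950; 64 * 1024 on the other gfx9 targets
  const int half_elems = lds_bytes / 2;  // scalar_t s[max_lds_len] in the wvSplitK kernels
  const int fp8_elems = lds_bytes / 1;   // fp8_t s[max_lds_len] in the wvSplitKQ kernels
  printf("half/bf16 elements: %d, fp8 elements: %d\n", half_elems, fp8_elems);
  return 0;
}
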

0 commit comments
