Commit 59f1b15

add gfx950 support for skinny gemms
Signed-off-by: charlifu <[email protected]>
1 parent 05a4324 commit 59f1b15

File tree

5 files changed: +87 -58 lines


csrc/rocm/skinny_gemms.cu

Lines changed: 65 additions & 43 deletions
@@ -13,14 +13,29 @@
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"

-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif

-#if defined(__HIPCC__) && defined(__gfx942__)
-  #define __HIP__MI300__
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__MI3XX__
 #endif

+#if defined(__gfx950__)
+  #define LDS_SIZE 160 * 1024
+#else
+  #define LDS_SIZE 64 * 1024
+#endif
+
+static int get_lds_size() {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx95");
+  if (substring == std::string::npos) return 64 * 1024;
+  return 160 * 1024;
+}
+
 #if defined(NDEBUG)
   #undef NDEBUG
 #include <assert.h>
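
The new capacity plumbing comes in two halves: LDS_SIZE is a compile-time byte budget the device code can size __shared__ arrays with (the __gfx*__ macros only exist during device compilation), while get_lds_size() repeats the same 64 KB vs 160 KB decision at runtime on the host by matching "gfx95" in the reported gcnArchName. gfx950 exposes 160 KB of LDS per CU, against 64 KB on gfx90a/gfx942. A minimal host-only sketch of how those byte budgets translate into the element counts used below (constants restated here for illustration; not part of the commit):

#include <cstdio>

constexpr int kLdsBytesGfx9 = 64 * 1024;     // gfx90a / gfx942
constexpr int kLdsBytesGfx950 = 160 * 1024;  // gfx950

int main() {
  // fp16/bf16 kernels hold 2-byte elements, hence the LDS_SIZE / 2 below:
  std::printf("fp16/bf16 elems: %d vs %d\n", kLdsBytesGfx9 / 2,
              kLdsBytesGfx950 / 2);
  // fp8 kernels hold 1-byte elements, so elements == bytes:
  std::printf("fp8 elems:       %d vs %d\n", kLdsBytesGfx9, kLdsBytesGfx950);
  return 0;
}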
@@ -267,15 +282,16 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
     V0 += (s.x + s.y); \
   }

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
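
Two things change at the top of each fp16/bf16 kernel: max_lds_len converts the byte budget into 2-byte element capacity, and the renamed __HIP__MI3XX__ guard extends the bf16 MFMA path from gfx942 to gfx950. A compile-time sketch of that trait selection, host-buildable with stand-in names (bf16_tag and IsMi3xx are assumptions, not the commit's identifiers):

#include <type_traits>

struct bf16_tag {};  // stand-in for __hip_bfloat16

template <typename scalar_t, bool IsMi3xx>
constexpr bool use_mfma_v = IsMi3xx && std::is_same_v<scalar_t, bf16_tag>;

static_assert(use_mfma_v<bf16_tag, true>);    // bf16 on gfx942/gfx950: MFMA
static_assert(!use_mfma_v<float, true>);      // other dtypes: dot path
static_assert(!use_mfma_v<bf16_tag, false>);  // gfx90a: dot path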
@@ -295,13 +311,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };

   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   //----------------------------------------------------
   // Fetch the activation matrix to LDS
@@ -312,11 +328,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
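
The loop above is a plain cooperative copy: each thread of each wave claims one A_CHUNK-wide vector per trip, so the whole work group strides through A in steps of THRDS * WvPrGrp * A_CHUNK until the min(K * N, max_lds_len) window is staged in LDS. A stripped-down sketch of the same pattern, compilable with hipcc/nvcc (FetchToLds is hypothetical and uses scalar copies where the real kernel uses vectorized bigType loads):

#include <cstdint>

// Hypothetical kernel restating the cooperative-fetch idiom above.
// Caller guarantees limit <= LDS_ELEMS, mirroring min(K * N, max_lds_len).
template <int THRDS, int WV_PER_GRP, int CHUNK, int LDS_ELEMS>
__global__ void FetchToLds(const float* __restrict__ A, uint32_t limit) {
  __shared__ float s[LDS_ELEMS];
  for (uint32_t k = 0; k < limit; k += THRDS * WV_PER_GRP * CHUNK) {
    // threadIdx.y selects the wave, threadIdx.x the lane within it,
    // matching the dim3 block(64, WvPrGrp) launches later in this file.
    uint32_t k_in = k + (threadIdx.y * THRDS + threadIdx.x) * CHUNK;
    if (k_in >= limit) break;
    for (int i = 0; i < CHUNK; ++i) s[k_in + i] = A[k_in + i];  // scalar copy
  }
  __syncthreads();
  // ... compute against s[] for the lifetime of the work group ...
}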
@@ -517,25 +533,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
@@ -561,7 +578,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -598,11 +615,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -686,7 +703,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Fetch A activation matrix in interleaved fashion from LDS or memory

       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 32 * 1024)
+        if (k_ + K * n < max_lds_len)
          bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
          bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -817,25 +834,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   }
 }

-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                              const scalar_t* __restrict__ A, scalar_t* C,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
@@ -855,13 +873,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };

   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];

   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -902,11 +920,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
 #define PCML
 #ifndef PCML
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;

     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -916,7 +934,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #define TUC (THRDS * UNRL * A_CHUNK)
   uint32_t kBase = 0;
   // find biggest k size that fits in LDS
-  uint32_t kFit = (32 * 1024) / N;
+  uint32_t kFit = (max_lds_len) / N;
   // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
   // of TUC
   kFit = (kFit % TUC == 0)
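
kFit is the largest per-row slice of K such that N rows fit in LDS at once, snapped to a multiple of TUC (the elements one fully unrolled trip of the work group consumes) so fetch windows stay aligned. The hunk cuts the expression off mid-statement, so the sketch below follows the round-up the in-source comment describes; its direction and the concrete numbers are assumptions for illustration:

#include <cstdio>

int main() {
  const unsigned max_lds_len = (64 * 1024) / 2;  // fp16 elements on gfx942
  const unsigned N = 3;                          // three activation rows
  const unsigned TUC = 64 * 4 * 8;  // THRDS * UNRL * A_CHUNK (assumed config)
  unsigned kFit = max_lds_len / N;  // 10922 elements per row
  kFit = (kFit % TUC == 0) ? kFit : (kFit - kFit % TUC + TUC);  // round up
  std::printf("kFit = %u (multiple of %u)\n", kFit, TUC);       // 12288
  return 0;
}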
@@ -1164,15 +1182,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                  const scalar_t* __restrict__ A, scalar_t* C,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
@@ -1222,17 +1240,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,

   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size() / 2;

 #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N)                                                          \
   {                                                                           \
     dim3 block(64, _WvPrGrp);                                                 \
-    if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) {                \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {              \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);              \
       wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N>          \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,    \
                                        CuCount);                              \
-    } else if (K_in * N_in <= 32 * 1024 * 1.2) {                              \
+    } else if (K_in * N_in <= max_lds_len * 1.2) {                            \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp);              \
       wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N>              \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,    \
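
With the runtime query in place, the launcher's tiering now scales with the device instead of hard-coding 32 K elements: the sml kernel runs when the whole K * N activation fits in LDS (and M divides evenly into YTILE-sized slabs), the marginal kernel covers up to 20% over capacity, and everything else falls through to the big kernel. A hedged host-side condensation of that selection (pick_wvSplitK_kernel is a hypothetical helper, not the commit's code):

#include <cstdint>

enum class SkinnyKernel { Sml, Mid, Big };

SkinnyKernel pick_wvSplitK_kernel(int64_t K, int64_t N, int64_t M,
                                  int max_lds_len, int ytile_s) {
  if (K * N <= max_lds_len && M % ytile_s == 0)
    return SkinnyKernel::Sml;      // A[] fully resident in LDS
  if (K * N <= max_lds_len * 1.2)  // matches the 1.2x slack above
    return SkinnyKernel::Mid;      // A[] marginally exceeds LDS
  return SkinnyKernel::Big;        // A[] streamed through LDS in pieces
}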
@@ -1272,7 +1291,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }

-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1300,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1296,10 +1316,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };

-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];

   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1436,7 +1456,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,16 +1466,17 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
                   const fp8_t* __restrict__ A, scalar_t* C,
                   const float* __restrict__ s_A, const float* __restrict__ s_B,
                   const int _WvPrGrp, const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1492,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };

-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];

   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1517,7 +1538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       uint32_t k_ = k + threadIdx.x * A_CHUNK;
       if (k_ >= K) break;
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 64 * 1024)
+        if (k_ + K * n < max_lds_len)
          bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
          bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1629,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1639,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                               const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support

 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1659,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
   dim3 grid(CuCount);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size();

 #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                   _N)                                                          \
   {                                                                            \
     dim3 block(64, _WvPrGrp);                                                  \
-    if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) {                 \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {               \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);               \
       wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N>  \
           <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
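
Note the asymmetry with the fp16/bf16 path: wvSplitK halves get_lds_size() because its LDS array holds 2-byte elements, while wvSplitKQ uses the raw byte count since fp8 elements are 1 byte each, matching constexpr int max_lds_len = LDS_SIZE in the fp8 kernels above. The rule reduces to a sizeof division; a minimal sketch (max_lds_elems is a hypothetical helper, with unsigned short/char standing in for the 2-byte and 1-byte dtypes):

// Hypothetical helper restating the element-capacity rule used above;
// lds_bytes stands in for get_lds_size().
template <typename T>
constexpr int max_lds_elems(int lds_bytes) {
  return lds_bytes / static_cast<int>(sizeof(T));  // /2 for fp16, /1 for fp8
}
static_assert(max_lds_elems<unsigned short>(64 * 1024) == 32 * 1024);  // fp16
static_assert(max_lds_elems<unsigned char>(64 * 1024) == 64 * 1024);   // fp8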
