
Commit 2d7a1f9

HIP: Add support for new gfx1200 and gfx1201 targets

1 parent f08f4b3

File tree: 6 files changed, 19 additions (+) and 13 deletions (-)

docs/build.md
1 addition, 1 deletion

@@ -189,7 +189,7 @@ The following compilation options are also available to tweak performance:
 
 | Option | Legal values | Default | Description |
 |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
+| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3, RDNA4). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
 | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
 | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
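
As an aside, a minimal sketch of the behaviour the GGML_CUDA_FORCE_MMQ row describes, assuming the CMake option is surfaced to the compiler as a definition of the same name; the helper below is hypothetical, not ggml's actual decision function, which also weighs batch size as the mmq.cu hunk further down shows:

    // Hypothetical illustration of the documented GGML_CUDA_FORCE_MMQ semantics.
    static bool use_mmq_sketch(const bool int8_mma_available) {
    #ifdef GGML_CUDA_FORCE_MMQ
        (void) int8_mma_available;
        return true;                 // forced: custom quantized kernels, lower VRAM use
    #else
        return int8_mma_available;   // default: MMQ only where int8 tensor cores exist
    #endif
    }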

ggml/src/ggml-cuda/common.cuh
9 additions, 7 deletions

@@ -49,24 +49,26 @@
 #define GGML_CUDA_CC_ADA_LOVELACE 890
 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
 
-// GCN/CNDA, wave size is 64
+// GCN/CDNA, wave size is 64
 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
 #define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
 
-// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
 
 #define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
 #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 

@@ -197,9 +199,9 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 #define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE

@@ -232,14 +234,14 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || cc >= GGML_CUDA_CC_RDNA4;
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || cc >= GGML_CUDA_CC_RDNA4;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

@@ -397,7 +399,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
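
For orientation, a self-contained sketch of how the new constants classify a device; the defines are copied from the hunk above, while the cc value used for gfx1200 is an assumption for illustration:

    #include <cassert>

    // Constants as added in common.cuh: AMD compute-capability values sit above a
    // large offset so they never collide with NVIDIA compute capabilities.
    #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
    #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000
    #define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

    #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
    #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)

    int main() {
        const int cc_gfx1200 = GGML_CUDA_CC_OFFSET_AMD + 0x1200; // assumed cc for a gfx1200 device

        // An RDNA4 card satisfies both predicates because the checks are open-ended;
        // callers that need an exact generation test the most specific macro first.
        assert(GGML_CUDA_CC_IS_RDNA4(cc_gfx1200));
        assert(GGML_CUDA_CC_IS_RDNA3(cc_gfx1200));
        return 0;
    }

Keeping every AMD value above the offset lets the same integer cc parameter describe both NVIDIA and AMD devices throughout the backend, which is what GGML_CUDA_CC_IS_AMD relies on.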

ggml/src/ggml-cuda/mmq.cu
1 addition, 1 deletion

@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
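
A sketch of the resulting cutoff, with the compute-capability macros repeated from common.cuh above and an illustrative batch-size constant (the real MMQ_DP4A_MAX_BATCH_SIZE value is not shown in this diff):

    #include <cstdint>

    #define GGML_CUDA_CC_OFFSET_AMD   0x1000000
    #define GGML_CUDA_CC_CDNA         (GGML_CUDA_CC_OFFSET_AMD + 0x908)
    #define GGML_CUDA_CC_RDNA1        (GGML_CUDA_CC_OFFSET_AMD + 0x1010)
    #define GGML_CUDA_CC_RDNA3        (GGML_CUDA_CC_OFFSET_AMD + 0x1100)
    #define GGML_CUDA_CC_RDNA4        (GGML_CUDA_CC_OFFSET_AMD + 0x1200)
    #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
    #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
    #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)

    // Illustrative only: on CDNA, RDNA3 and now RDNA4, the quantized MMQ path is kept
    // for small batches; above the cutoff the FP16 BLAS path is preferred instead.
    static bool mmq_allowed_for_batch(const int cc, const int64_t ne11) {
        const int64_t max_dp4a_batch = 64; // assumed cutoff, for illustration
        const bool fast_fp16_amd = GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_CDNA(cc);
        return !fast_fp16_amd || ne11 < max_dp4a_batch;
    }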

ggml/src/ggml-cuda/mmq.cuh
2 additions, 2 deletions

@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 __launch_bounds__(WARP_SIZE*nwarps, 1)
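
For context, a small standalone HIP kernel showing what the __launch_bounds__(WARP_SIZE*nwarps, 2) qualifier gated above requests; WARP_SIZE, nwarps and the kernel body are placeholders, not the values the mul_mat_q template instantiates:

    #include <hip/hip_runtime.h>

    constexpr int WARP_SIZE = 32; // RDNA wave size; placeholder value
    constexpr int nwarps    = 8;  // warps per block; placeholder value

    // __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor): asking for at
    // least 2 resident blocks per compute unit caps per-thread register usage, which is
    // what the #if above now also requests on RDNA4.
    __global__ void
    __launch_bounds__(WARP_SIZE * nwarps, 2)
    scale_kernel(float * data, const float factor, const int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            data[i] *= factor;
        }
    }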

ggml/src/ggml-cuda/mmvq.cu
2 additions, 2 deletions

@@ -54,7 +54,7 @@ enum mmvq_parameter_table_id {
 };
 
 static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3)
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
     return MMVQ_PARAMETERS_RDNA2;
 #elif defined(GCN) || defined(CDNA)
     return MMVQ_PARAMETERS_GCN;

@@ -64,7 +64,7 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 }
 
 static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         return MMVQ_PARAMETERS_RDNA2;
     }
     if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
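
A sketch of the device-side selector in isolation; the enum members beyond MMVQ_PARAMETERS_RDNA2 and MMVQ_PARAMETERS_GCN, and their values, are assumptions, and the arch defines are the ones from vendors/hip.h. Because the function is constexpr over compile-time defines, a gfx120x offload target folds straight to the RDNA2 parameter table, and the host overload above reaches the same answer from the runtime cc value:

    #include <hip/hip_runtime.h>

    enum mmvq_parameter_table_id_sketch {
        MMVQ_PARAMETERS_GENERIC_SKETCH = 0, // assumed default member
        MMVQ_PARAMETERS_GCN_SKETCH,
        MMVQ_PARAMETERS_RDNA2_SKETCH,
    };

    // Resolved per offload target at compile time, so no runtime branch remains.
    static constexpr __device__ mmvq_parameter_table_id_sketch get_table_id_sketch() {
    #if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
        return MMVQ_PARAMETERS_RDNA2_SKETCH;
    #elif defined(GCN) || defined(CDNA)
        return MMVQ_PARAMETERS_GCN_SKETCH;
    #else
        return MMVQ_PARAMETERS_GENERIC_SKETCH;
    #endif
    }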

ggml/src/ggml-cuda/vendors/hip.h
4 additions, 0 deletions

@@ -150,6 +150,10 @@
 #define CDNA
 #endif
 
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define RDNA4
+#endif
+
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
     defined(__gfx1150__) || defined(__gfx1151__)
 #define RDNA3
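
A minimal sketch of how the new define reaches device code, assuming compilation with hipcc for a gfx1200 or gfx1201 offload target; the kernel and its output value are made up for illustration:

    #include <hip/hip_runtime.h>

    // Mirror of the new vendors/hip.h block: the per-target __gfx1200__/__gfx1201__
    // builtins are collapsed into a single generation macro.
    #if defined(__gfx1200__) || defined(__gfx1201__)
    #define RDNA4
    #endif

    // Kernels can then branch on the generation instead of listing every gfx ID,
    // e.g. to pick the RDNA3/RDNA4 dp4a builtin in ggml_cuda_dp4a above.
    __global__ void report_generation(int * out) {
    #if defined(RDNA4)
        *out = 4;
    #else
        *out = 0;
    #endif
    }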
