
Commit 4ab0ad5

Merge branch 'ggml-org:master' into master
2 parents 0c561e6 + 95ce098 commit 4ab0ad5

13 files changed, +92 −66 lines

.devops/rocm.Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
 # check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
 
-ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
+ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
 #ARG ROCM_DOCKER_ARCH='gfx1151'
 
 # Set ROCm architectures

CODEOWNERS

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@
 /ggml/src/ggml-cuda/mmq.* @JohannesGaessler
 /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
 /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
+/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
+/ggml/src/ggml-hip/ @IMbackK
+/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
 /ggml/src/ggml-impl.h @ggerganov @slaren
 /ggml/src/ggml-metal/ @ggerganov
 /ggml/src/ggml-opencl/ @lhez @max-krasnyansky

convert_hf_to_gguf.py

Lines changed: 11 additions & 6 deletions
@@ -4250,7 +4250,8 @@ def set_gguf_parameters(self):
         # This logic matches modeling_plamo.py's is_mamba function
         mamba_step = hparams.get("mamba_step", 2)
         mamba_enabled = hparams.get("mamba_enabled", True)
-        mamba_layers = []
+        num_key_value_heads = []
+        num_attention_heads = []
 
         if mamba_enabled:
             for i in range(block_count):
@@ -4260,17 +4261,21 @@ def set_gguf_parameters(self):
                 else:
                     is_mamba = (i % mamba_step) != (mamba_step // 2)
                 if is_mamba:
-                    mamba_layers.append(0)
+                    num_key_value_heads.append(0)
+                    num_attention_heads.append(0)
                 else:
-                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+                    num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+                    num_attention_heads.append(hparams.get("num_attention_heads", 32))
 
-        if mamba_layers:
-            self.gguf_writer.add_head_count_kv(mamba_layers)
+        if num_key_value_heads and num_attention_heads:
+            self.gguf_writer.add_head_count_kv(num_key_value_heads)
+            self.gguf_writer.add_head_count(num_attention_heads)
 
         self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
         self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))

ggml/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP"
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
 option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
 option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
 option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 29 deletions
@@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-
 #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
@@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) {
            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
-// Any FP16 tensor core instructions are available for ggml code.
-static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-    return false;
-#else
-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
-        GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        return true;
-    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-        return true;
-#else
-        return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-    } else {
-        return false;
-    }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
 
 // kq_stride == number of KQ rows to process per iteration
 // kq_nbatch == number of K columns to load in parallel for KQ calculation
@@ -190,10 +191,10 @@ static __global__ void flash_attn_tile(
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
     NO_DEVICE_CODE;
    return;
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 6 additions & 6 deletions
@@ -6,19 +6,19 @@
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
 #if !defined(GGML_USE_HIP)
 #include <mma.h>
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA)
 namespace wmma = mtmusa::wmma;
 #else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#elif defined(GGML_USE_HIP)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
         ne31, ne32, ne33,
         nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
 }
 
 constexpr int get_max_power_of_2(int x) {
ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 46 additions & 0 deletions
@@ -1,3 +1,49 @@
 #include "common.cuh"
 
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_CDNA(cc)){
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+    } else {
+        return false;
+    }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
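
Editor's note on the rocWMMA gate above: the condition ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0 is true for every 1.x release and for any release with a non-zero minor or patch, and false only for x.0.0 releases with major >= 2 — in practice, the rocWMMA v2.0.0 build that the #warning flags as broken on CDNA. A minimal sketch of the same predicate, assuming only the three version macros; the helper name below is illustrative and not part of the commit:

// Editor's illustrative sketch (not part of the commit): mirrors the preprocessor
// version check from fattn-wmma-f16.cuh as a constexpr helper that can be tested
// on the host. The name rocwmma_usable_on_cdna is hypothetical.
#include <cstdio>

constexpr bool rocwmma_usable_on_cdna(int major, int minor, int patch) {
    // False only for x.0.0 releases with major >= 2; in practice this rejects
    // exactly the rocWMMA v2.0.0 build flagged as broken for CDNA flash attention.
    return major < 2 || minor > 0 || patch > 0;
}

static_assert( rocwmma_usable_on_cdna(1, 7, 0), "any 1.x release passes");
static_assert(!rocwmma_usable_on_cdna(2, 0, 0), "v2.0.0 is rejected");
static_assert( rocwmma_usable_on_cdna(2, 0, 1), "v2.0.1 and later pass");

int main() {
    std::printf("rocWMMA 2.0.0 usable on CDNA: %s\n",
                rocwmma_usable_on_cdna(2, 0, 0) ? "yes" : "no");
    return 0;
}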

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 2 deletions
@@ -222,7 +222,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != K->ne[0]) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+            if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -300,7 +300,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // For large batch sizes, use the WMMA kernel if possible:
-    if (fp16_mma_available(cc)) {
+    if (ggml_cuda_should_use_wmma_fattn(cc)) {
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
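
Editor's note: both call sites above now route through ggml_cuda_should_use_wmma_fattn(cc) where they previously called fp16_mma_available(cc). A simplified sketch of the selection pattern, under stated assumptions: the enum, struct, and pick_fattn_kernel helper are stand-ins rather than the actual ggml_cuda_get_best_fattn_kernel, and the preference for the newer MMA kernel over WMMA is assumed from surrounding code, not shown in this diff:

// Editor's illustrative sketch (not part of the commit). All names are hypothetical
// stand-ins for the selection logic in ggml/src/ggml-cuda/fattn.cu.
#include <cstdio>

enum fattn_kernel_choice { KERNEL_NONE, KERNEL_TILE, KERNEL_WMMA_F16, KERNEL_MMA_F16 };

struct device_caps {
    bool wmma_fattn_ok;  // stand-in for ggml_cuda_should_use_wmma_fattn(cc)
    bool turing_mma_ok;  // stand-in for turing_mma_available(cc)
};

static fattn_kernel_choice pick_fattn_kernel(const device_caps & caps, int n_batch) {
    if (!caps.wmma_fattn_ok && !caps.turing_mma_ok) {
        return KERNEL_NONE;       // no tensor-core path available at all
    }
    if (caps.turing_mma_ok) {
        return KERNEL_MMA_F16;    // assumed: prefer the newer MMA kernel when present
    }
    if (n_batch > 1) {
        return KERNEL_WMMA_F16;   // large batch sizes: use the WMMA kernel if possible
    }
    return KERNEL_TILE;           // otherwise fall back to the generic tile kernel
}

int main() {
    const device_caps volta_like = { true, false };  // e.g. a WMMA-only GPU
    std::printf("batch 8 -> %d\n", (int) pick_fattn_kernel(volta_like, 8));  // 2 = KERNEL_WMMA_F16
    std::printf("batch 1 -> %d\n", (int) pick_fattn_kernel(volta_like, 1));  // 1 = KERNEL_TILE
    return 0;
}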

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma-version.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
