Skip to content

Commit 11f0af5

Browse files
CUDA: faster tile FA, add oob checks, more HSs (ggml-org#16492)
1 parent a3cb047 commit 11f0af5

18 files changed

+1358
-784
lines changed

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
4444
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
4545

4646
file(GLOB GGML_SOURCES_CUDA "*.cu")
47+
file(GLOB SRCS "template-instances/fattn-tile*.cu")
48+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4749
file(GLOB SRCS "template-instances/fattn-mma*.cu")
4850
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4951
file(GLOB SRCS "template-instances/mmq*.cu")

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
245245
}
246246

247247
static bool fast_fp16_available(const int cc) {
248-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
248+
return GGML_CUDA_CC_IS_AMD(cc) ||
249+
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
249250
}
250251

251252
// To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
571572
}
572573

573574
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
575+
// Important: do not use this function if dst and src both point at registers.
576+
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
577+
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
578+
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
574579
template <int nbytes, int alignment = 0>
575580
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
576581
if constexpr (alignment != 0) {

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -793,8 +793,6 @@ void launch_fattn(
793793
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
794794
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
795795

796-
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
797-
798796
ggml_cuda_pool & pool = ctx.pool();
799797
cudaStream_t main_stream = ctx.stream();
800798
const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
878876
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
879877
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
880878
// multiple sequences of possibly different lengths.
881-
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
879+
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
882880
const int s31 = mask->nb[1] / sizeof(half2);
883881
const int s33 = mask->nb[3] / sizeof(half2);
884882

@@ -916,8 +914,7 @@ void launch_fattn(
916914

917915
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
918916
} else {
919-
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
920-
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
917+
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
921918

922919
// parallel_blocks must not be larger than what the tensor size allows:
923920
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(
946943

947944
blocks_num.x = ntiles_x;
948945
blocks_num.y = parallel_blocks;
949-
blocks_num.z = Q->ne[2]*Q->ne[3];
946+
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
950947

951948
if (parallel_blocks > 1) {
952949
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));

0 commit comments

Comments
 (0)