Merged
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB GGML_SOURCES_CUDA "*.cu")
+    file(GLOB SRCS "template-instances/fattn-tile*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/mmq*.cu")
7 changes: 6 additions & 1 deletion ggml/src/ggml-cuda/common.cuh
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
 }
 
 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
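The key change in this hunk: the fast-FP16 check now consults the arch the binary was actually compiled for, not the device's raw compute capability, so feature selection stays consistent with the kernels that will actually run. A minimal sketch of what such a lookup could do, with an invented arch list and helper name (not ggml's actual implementation):

    // Hypothetical sketch: return the highest compute capability the binary
    // was compiled for that the device (cc) can still run.
    static int highest_compiled_arch_sketch(const int cc) {
        const int compiled_archs[] = {520, 700, 750, 800}; // invented build configuration
        int best = compiled_archs[0];
        for (const int arch : compiled_archs) {
            if (arch <= cc && arch > best) {
                best = arch;
            }
        }
        return best; // e.g. cc == 610 -> 520 here: the device runs the cc 5.2 kernels
    }

Under the old check a cc 6.1 device was always excluded, even when it executes code compiled for a different arch.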
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
 }
 
 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+// Important: do not use this function if dst and src both point at registers.
+// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
+// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
+// If dst and src point at different address spaces then they are guaranteed to not be aliased.
 template <int nbytes, int alignment = 0>
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
     if constexpr (alignment != 0) {
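For readers unfamiliar with the helper, here is a rough sketch of how a width-dispatching copy like ggml_cuda_memcpy_1 can be built (a simplified sketch, not the exact ggml code; the alignment parameter is omitted):

    // Copy nbytes through the widest pointee type the size allows, so the
    // compiler emits single 16/8/4-byte loads/stores instead of a byte loop.
    template <int nbytes>
    static __device__ __forceinline__ void memcpy_sketch(void * __restrict__ dst, const void * __restrict__ src) {
        if constexpr (nbytes % 16 == 0) {
    #pragma unroll
            for (int i = 0; i < nbytes/16; ++i) {
                ((int4 *) dst)[i] = ((const int4 *) src)[i];
            }
        } else if constexpr (nbytes % 8 == 0) {
    #pragma unroll
            for (int i = 0; i < nbytes/8; ++i) {
                ((int2 *) dst)[i] = ((const int2 *) src)[i];
            }
        } else if constexpr (nbytes % 4 == 0) {
    #pragma unroll
            for (int i = 0; i < nbytes/4; ++i) {
                ((int *) dst)[i] = ((const int *) src)[i];
            }
        } else {
    #pragma unroll
            for (int i = 0; i < nbytes; ++i) {
                ((char *) dst)[i] = ((const char *) src)[i];
            }
        }
    }

Because each branch reads and writes through a single pointee type, the strict-aliasing caveat in the new comment only bites when both pointers are register-backed locals of different types.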
9 changes: 3 additions & 6 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -793,8 +793,6 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     // multiple sequences of possibly different lengths.
-    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);
 
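The new K->ne[1] % FATTN_KQ_STRIDE == 0 condition stands in for the assert deleted above: the scan works on FATTN_KQ_STRIDE x FATTN_KQ_STRIDE squares, so it is simply skipped when the KV length is not a multiple of the stride. As a toy illustration (the stride value 256 is an assumption, not taken from this diff):

    #include <cstdint>

    // Toy predicate mirroring the new guard; 256 is an assumed FATTN_KQ_STRIDE.
    static bool can_use_mask_scan(const int64_t kv_len, const int64_t stride = 256) {
        return kv_len % stride == 0; // 4096 -> true, 4000 -> false: scan skipped
    }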
@@ -916,8 +914,7 @@
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
-        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
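With the padding assert gone, K->ne[1] is no longer guaranteed to be a multiple of KQ_row_granularity, so the tile count switches from an exact division to a ceiling division: a final, partially filled tile still needs its own block. The idiom, shown standalone:

    // Integer ceiling division as used above: round up so a partial tile counts.
    static int ceil_div(const int x, const int y) {
        return (x + y - 1) / y; // e.g. ceil_div(1000, 256) == 4, ceil_div(1024, 256) == 4
    }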
@@ -946,7 +943,7 @@
 
     blocks_num.x = ntiles_x;
     blocks_num.y = parallel_blocks;
-    blocks_num.z = Q->ne[2]*Q->ne[3];
+    blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
 
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
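The grid's z dimension now counts groups of ncols2 heads rather than individual heads, i.e. ncols2 query heads are processed together by one block. A self-contained illustration of the resulting launch geometry (all numbers invented, names local to this sketch):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        // Invented shapes: 32 query heads, ncols2 == 4 heads handled per block, batch 2.
        const int ne2 = 32, ne3 = 2, ncols2 = 4;
        const int ntiles_x = 8, parallel_blocks = 3;
        const dim3 blocks_num(ntiles_x, parallel_blocks, (ne2/ncols2)*ne3);
        printf("grid = (%u, %u, %u)\n", blocks_num.x, blocks_num.y, blocks_num.z); // grid = (8, 3, 16)
        return 0;
    }

With the old formula the z dimension for these shapes would have been 32*2 == 64 blocks; grouping cuts it to 16.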