
Commit 4d1fde0

Cut flash attention from CUDA again
If you really want it anyway, just say:

    ./llamafile -ngl 999 -fa --recompile ...

and it'll build ggml-cuda with flash attention for your system.
1 parent 629e208 commit 4d1fde0

File tree (3 files changed: +13 −5)

  llama.cpp/common.cpp
  llama.cpp/ggml-cuda.cu
  llamafile/cuda.c

llama.cpp/common.cpp (1 addition, 0 deletions)

```diff
@@ -901,6 +901,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "-fa" || arg == "--flash-attn") {
         params.flash_attn = true;
+        FLAG_flash_attn = true; // [jart]
         return true;
     }
     if (arg == "--color") {
```

llama.cpp/ggml-cuda.cu (9 additions, 4 deletions)

```diff
@@ -23,7 +23,6 @@
 #include <string>
 #include <vector>
 
-
 ////////////////////////////////////////////////////////////////////////////////
 //
 // ROLLUP acc.cu
@@ -3608,6 +3607,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_padded_row_size);
 }
 
+#ifndef GGML_MINIMIZE_CODE_SIZE
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // ROLLUP fattn.cu
@@ -5098,7 +5099,6 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
 //
 ////////////////////////////////////////////////////////////////////////////////
 
-
 template<int D, int ncols, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -5432,7 +5432,6 @@ void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, gg
 //
 ////////////////////////////////////////////////////////////////////////////////
 
-
 template<int D, int ncols, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -5709,6 +5708,8 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens
     launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
 
+#endif // GGML_MINIMIZE_CODE_SIZE
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // ROLLUP getrows.cu
@@ -13096,7 +13097,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_argsort(ctx, dst);
             break;
         case GGML_OP_FLASH_ATTN_EXT:
+#ifndef GGML_MINIMIZE_CODE_SIZE
             ggml_cuda_flash_attn_ext(ctx, dst);
+#endif
             break;
         default:
             return false;
@@ -13649,7 +13652,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_MINIMIZE_CODE_SIZE)
+            return false;
+#elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
             if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
```
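These hunks implement a compile-time opt-out: when GGML_MINIMIZE_CODE_SIZE is defined, the flash-attention kernel rollups are excluded from the translation unit, supports_op() reports GGML_OP_FLASH_ATTN_EXT as unsupported, and the now-empty case in compute_forward() is never reached. A minimal sketch of the pattern, with the ggml types and names simplified to keep it self-contained:

```cpp
#include <cstdio>

enum op_t { OP_ARGSORT, OP_FLASH_ATTN_EXT };

#ifndef GGML_MINIMIZE_CODE_SIZE
// In ggml-cuda.cu this region holds the ROLLUP fattn*.cu kernels.
static void flash_attn_ext() { std::puts("flash attention kernel"); }
#endif // GGML_MINIMIZE_CODE_SIZE

static bool supports_op(op_t op) {
    switch (op) {
        case OP_FLASH_ATTN_EXT:
#if defined(GGML_MINIMIZE_CODE_SIZE)
            return false;  // op never offered, so compute() is never asked
#else
            return true;   // the real code also checks head sizes here
#endif
        default:
            return true;
    }
}

static bool compute(op_t op) {
    switch (op) {
        case OP_FLASH_ATTN_EXT:
#ifndef GGML_MINIMIZE_CODE_SIZE
            flash_attn_ext();
#endif
            return true;   // the guarded case still ends in break/return
        default:
            return false;
    }
}

int main() {
    if (supports_op(OP_FLASH_ATTN_EXT)) {
        compute(OP_FLASH_ATTN_EXT);
    } else {
        std::puts("flash attention compiled out; scheduler falls back");
    }
}
```

Compile the sketch with -DGGML_MINIMIZE_CODE_SIZE and the fallback branch runs; compile without it and the stub "kernel" executes.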

llamafile/cuda.c (3 additions, 1 deletion)

```diff
@@ -61,7 +61,8 @@ __static_yoink("llama.cpp/ggml-backend-impl.h");
     /* "-DNDEBUG", */ "-DGGML_BUILD=1", "-DGGML_SHARED=1", "-DGGML_MULTIPLATFORM", \
     "-DGGML_CUDA_DMMV_X=32", "-DK_QUANTS_PER_ITERATION=2", \
     "-DGGML_CUDA_PEER_MAX_BATCH_SIZE=128", "-DGGML_CUDA_MMV_Y=1", \
-    (FLAG_tinyblas ? "-DGGML_USE_TINYBLAS" : "-DGGML_USE_CUBLAS")
+    (FLAG_tinyblas ? "-DGGML_USE_TINYBLAS" : "-DGGML_USE_CUBLAS"), \
+    (FLAG_flash_attn ? "-DTEHFLASH" : "-DGGML_MINIMIZE_CODE_SIZE")
 
 #define NVCC_FLAGS \
     "-std=c++11", "-O3", "--shared", "--use_fast_math", "--forward-unknown-to-host-compiler", \
@@ -567,6 +568,7 @@ static bool compile_amd_windows(const char *clangxx, const char *dso, const char
         "-DGGML_CUDA_PEER_MAX_BATCH_SIZE=128",
         "-DGGML_CUDA_MMV_Y=1",
         "-DGGML_USE_TINYBLAS",
+        FLAG_flash_attn ? "-DTEHFLASH" : "-DGGML_MINIMIZE_CODE_SIZE",
         "-o",
         (char *)tmpdso,
         (char *)src,
```
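A rough sketch of how that flag list takes shape (simplified; the real file splices these strings into the compiler's argument vector, and the FLAG_* defaults below are made up for illustration):

```cpp
#include <cstdio>
#include <vector>

// Illustrative stand-ins for llamafile's globals.
static bool FLAG_tinyblas   = true;
static bool FLAG_flash_attn = false;  // set by -fa / --flash-attn

int main() {
    std::vector<const char *> defines = {
        "-DGGML_BUILD=1",
        "-DGGML_CUDA_MMV_Y=1",
        FLAG_tinyblas ? "-DGGML_USE_TINYBLAS" : "-DGGML_USE_CUBLAS",
        // -DTEHFLASH keeps the flash-attention kernels; the default
        // -DGGML_MINIMIZE_CODE_SIZE compiles them out (see ggml-cuda.cu).
        FLAG_flash_attn ? "-DTEHFLASH" : "-DGGML_MINIMIZE_CODE_SIZE",
    };
    for (const char *d : defines) {
        std::printf("%s\n", d);
    }
}
```

So -fa on the command line flows through FLAG_flash_attn into -DTEHFLASH at --recompile time, while the default build gets -DGGML_MINIMIZE_CODE_SIZE and sheds the flash-attention kernels.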
