
Commit b4955f0

implement "auto" mode for clip flash attn
1 parent 19116a4 commit b4955f0

File tree (6 files changed, +74 -10 lines):

tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd-cli.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h
tools/server/server.cpp

tools/mtmd/clip.cpp

Lines changed: 67 additions & 10 deletions
@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -427,12 +428,14 @@ struct clip_ctx {
 
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
 
     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;
 
     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2261,16 +2264,36 @@ struct clip_graph {
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);
 
-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-        //cb(k, "v", il);
+        ggml_tensor * cur;
 
-        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
-        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
 
-        ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }
 
         cb(cur, "kqv_out", il);
 
@@ -3181,7 +3204,30 @@ struct clip_model_loader {
         }
     }
 
-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+            (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
@@ -3217,6 +3263,17 @@
         const int n_nodes = ggml_graph_n_nodes(gf);
 
         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+        return true;
     }
 
     void get_bool(const std::string & key, bool & output, bool required = true) {
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         if (loader.has_vision) {
            ctx_vision = new clip_ctx(ctx_params);
            loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
            loader.load_tensors(*ctx_vision);
-           loader.alloc_compute_meta(*ctx_vision);
+           loader.warmup(*ctx_vision);
        }
 
        if (loader.has_audio) {
            ctx_audio = new clip_ctx(ctx_params);
            loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
            loader.load_tensors(*ctx_audio);
-           loader.alloc_compute_meta(*ctx_audio);
+           loader.warmup(*ctx_audio);
        }
 
    } catch (const std::exception & e) {
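The "auto" mode above rests on two pieces: alloc_compute_meta() now returns false when any GGML_OP_FLASH_ATTN_EXT node in the warm-up graph is rejected by ggml_backend_supports_op(), and warmup() first builds the graph with flash attention enabled, then rebuilds it with the unfused mul_mat/soft_max path if that probe fails. Below is a minimal sketch of the probing pattern, factored into a standalone helper for illustration only; the helper name and the generic op parameter are not part of the commit.

#include "ggml.h"
#include "ggml-backend.h"

// Sketch: walk an already-built graph and report whether every node with the
// given op can run on `backend` -- the same check alloc_compute_meta() performs
// for GGML_OP_FLASH_ATTN_EXT before warmup() decides to keep flash attention on.
static bool graph_op_supported(ggml_backend_t backend, struct ggml_cgraph * gf, enum ggml_op op) {
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);
        if (node->op == op && !ggml_backend_supports_op(backend, node)) {
            return false; // at least one node would have to fall back
        }
    }
    return true;
}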

tools/mtmd/clip.h

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "mtmd.h"
 #include <stddef.h>
 #include <stdint.h>
 
@@ -25,6 +26,7 @@ enum clip_modality {
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    llama_flash_attn_type flash_attn_type;
 };
 
 struct clip_init_result {
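For code that talks to the clip layer directly, the new setting is just one more field of clip_context_params. A hedged caller-side sketch, not part of the commit; the projector path is a placeholder and error handling is omitted.

#include "clip.h"

// Sketch: ask for automatic flash-attention detection when loading a projector.
static clip_ctx * load_vision_projector(const char * mmproj_path /* placeholder .gguf path */) {
    clip_context_params ctx_params;
    ctx_params.use_gpu         = true;
    ctx_params.verbosity       = GGML_LOG_LEVEL_INFO;
    ctx_params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // or ..._ENABLED / ..._DISABLED
    clip_init_result res = clip_init(mmproj_path, ctx_params);
    return res.ctx_v; // null when the file contains no vision encoder
}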

tools/mtmd/mtmd-cli.cpp

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
         mparams.print_timings = true;
         mparams.n_threads = params.cpuparams.n_threads;
         mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type = params.flash_attn_type;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);

tools/mtmd/mtmd.cpp

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     return params;
 }
 
@@ -164,6 +165,7 @@ struct mtmd_context {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
+        ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;

tools/mtmd/mtmd.h

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    llama_flash_attn_type flash_attn_type;
 };
 
 MTMD_API const char * mtmd_default_marker(void);
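At the mtmd level the preference travels through mtmd_context_params, and mtmd_context_params_default() pre-sets it to LLAMA_FLASH_ATTN_TYPE_AUTO, so callers that never touch the field get automatic backend probing. A hedged sketch mirroring what mtmd-cli.cpp and server.cpp do in this commit; the projector file name is a placeholder and the model is assumed to be loaded already.

#include "mtmd.h"

// Sketch: create an mtmd context with an explicit flash-attention preference.
static mtmd_context * init_mtmd(const struct llama_model * model) {
    mtmd_context_params mparams = mtmd_context_params_default(); // flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO
    mparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;     // force it on; AUTO / DISABLED also work
    return mtmd_init_from_file("mmproj-model.gguf", model, mparams); // placeholder file name
}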

tools/server/server.cpp

Lines changed: 1 addition & 0 deletions
@@ -2456,6 +2456,7 @@ struct server_context {
             mparams.print_timings = false;
             mparams.n_threads = params_base.cpuparams.n_threads;
             mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+            mparams.flash_attn_type = params_base.flash_attn_type;
             mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
             if (mctx == nullptr) {
                 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
