clip : use FA #16837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Commits (9):
- 3aa835b  clip : use FA  (ggerganov)
- a4b54f2  cont : add warning about unsupported ops  (ggerganov)
- 19116a4  Merge branch 'master' into gg/clip-fa  (ngxson)
- b4955f0  implement "auto" mode for clip flash attn  (ngxson)
- bdb43f6  clip : print more detailed op support info during warmup  (ggerganov)
- 29330dc  cont : remove obsolete comment [no ci]  (ggerganov)
- b67a168  improve debugging message  (ngxson)
- cdb3dea  trailing space  (ngxson)
- d441c31  metal : remove stray return  (ggerganov)
Diff (changes from 4 of 9 commits):

@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -28,6 +29,7 @@
 #include <numeric>
 #include <functional>

 // TODO: allow to pass callback from user code
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

 enum ffn_op_type {
@@ -426,12 +428,14 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;

     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2260,17 +2264,25 @@ struct clip_graph {
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);

-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
-        v = ggml_cont(ctx0, v);
-        //cb(k, "v", il);
-
         ggml_tensor * cur;

-        // TODO @ngxson : support flash attention
-        {
+        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
             const auto n_tokens = q->ne[1];
             const auto n_head   = q->ne[2];
             // const auto n_kv = k->ne[1]; // for flash attention

             ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             // F32 may not needed for vision encoders?
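A note on the two v_cur permutations above (my reading of the ggml API, not text from the PR): ggml_flash_attn_ext expects V in the same layout as K, while the ggml_mul_mat path consumes a transposed V, which is why the two branches build v differently. Roughly:

```cpp
// Shapes in ggml order [ne0, ne1, ne2], with d = head size, T = tokens, H = heads
// (illustrative, derived from the permutes above):
//   flash-attn branch: q [d, T, H], k [d, T, H], v [d, T, H]   // V kept in the same layout as K
//   mul_mat branch:    q [d, T, H], k [d, T, H], v [T, d, H]   // V transposed so ggml_mul_mat(v, kq) yields [d, T, H]
```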
@@ -3192,7 +3204,30 @@
         }
     }

-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+                (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@@ -3223,6 +3258,22 @@
                     size / 1024.0 / 1024.0);
             }
         }
+
+        const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
+        const int n_nodes  = ggml_graph_n_nodes(gf);
+
+        LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+
+        return true;
     }

Comment on lines 3326 to 3336:

I tested this by temporarily modifying the code to always return false here. Currently, mtmd only supports 2 backends at the same time: CPU and one GPU backend.
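For reference, a hypothetical local edit that exercises this fallback (not part of the PR) would be to make the check above always fail for flash-attention nodes:

```cpp
// hypothetical test-only tweak inside the loop above: treat every
// FLASH_ATTN_EXT node as unsupported so the DISABLED path is taken
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
    return false;
}
```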
@@ -3312,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
            ctx_vision = new clip_ctx(ctx_params);
            loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
            loader.load_tensors(*ctx_vision);
-           loader.alloc_compute_meta(*ctx_vision);
+           loader.warmup(*ctx_vision);
        }

        if (loader.has_audio) {
            ctx_audio = new clip_ctx(ctx_params);
            loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
            loader.load_tensors(*ctx_audio);
-           loader.alloc_compute_meta(*ctx_audio);
+           loader.warmup(*ctx_audio);
        }

    } catch (const std::exception & e) {
@@ -4485,6 +4536,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        return false; // only support batch size of 1
    }

+   if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
+       LOG_WRN("%s: *****************************************************************\n", __func__);
+       LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+       LOG_WRN("%s: the performance will be suboptimal                               \n", __func__);
+       LOG_WRN("%s:                                                                  \n", __func__);
+       LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+       LOG_WRN("%s: *****************************************************************\n", __func__);
+   }
+
    // build the inference graph
    ctx->debug_print_tensors.clear();
    ggml_backend_sched_reset(ctx->sched.get());
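When that branch triggers, the user sees roughly the following in the log (an illustrative rendering of the LOG_WRN calls above, assuming the logger prints the format string as-is with the function name substituted):

```
clip_image_batch_encode: *****************************************************************
clip_image_batch_encode: WARNING: the CLIP graph uses unsupported operators by the backend
clip_image_batch_encode: the performance will be suboptimal
clip_image_batch_encode: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118
clip_image_batch_encode: *****************************************************************
```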
ngxson:
@ggerganov I implemented a simple solution that auto-enables flash attention only when the backend supports it. We should probably make this LOG_WRN more prominent. Also, what kind of info do you think should be displayed here?

Some users may already be running models with shapes that GPU flash attention does not support. Falling back to CPU would suddenly make inference very slow, which is bad UX overall. The auto mode plus a prominent warning is the better solution, and it also encourages users to voluntarily report the relevant info back to us, without forcing them.
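For downstream code, the mode can also be pinned explicitly rather than left on auto. A minimal sketch, assuming only the names visible in this diff (clip_context_params, clip_init, llama_flash_attn_type) and that the enum is reachable through the clip/mtmd headers; the path and function name are placeholders:

```cpp
#include "clip.h"   // clip_context_params / clip_init, as used in the diff above

// Sketch: pin the flash-attention mode explicitly instead of relying on
// LLAMA_FLASH_ATTN_TYPE_AUTO; other context-params fields are left
// zero-initialized for brevity (a real caller would also set use_gpu, etc.).
static clip_init_result load_mmproj_no_fa(const char * mmproj_path) {
    clip_context_params params{};
    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; // or ..._ENABLED to force it on
    return clip_init(mmproj_path, params);
}
```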
ggerganov:
We can print the actual tensor (shape, strides, types) for which FA is not supported.
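A sketch of how that could look, using only public ggml_tensor fields and helpers (ggml_op_name, ggml_type_name) plus the LOG_WRN macro from clip-impl.h; the log format and where it gets called are assumptions, not part of this PR:

```cpp
// sketch: log shape/stride/type details for a flash-attention node the backend rejects
static void log_unsupported_fa(const ggml_tensor * node) {
    LOG_WRN("%s: unsupported %s node '%s'\n", __func__, ggml_op_name(node->op), node->name);
    for (int i = 0; i < GGML_MAX_SRC && node->src[i] != nullptr; i++) {
        const ggml_tensor * t = node->src[i];
        LOG_WRN("%s:   src%d '%s': type=%s, ne=[%lld, %lld, %lld, %lld], nb=[%zu, %zu, %zu, %zu]\n",
                __func__, i, t->name, ggml_type_name(t->type),
                (long long) t->ne[0], (long long) t->ne[1], (long long) t->ne[2], (long long) t->ne[3],
                t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
    }
}
```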