4 | 4 | // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch |
5 | 5 | #include "clip.h" |
6 | 6 | #include "clip-impl.h" |
| 7 | +#include "mtmd.h" |
7 | 8 | #include "ggml.h" |
8 | 9 | #include "ggml-cpp.h" |
9 | 10 | #include "ggml-cpu.h" |
@@ -427,12 +428,14 @@ struct clip_ctx { |
427 | 428 |
428 | 429 | int max_nodes = 8192; |
429 | 430 | ggml_backend_sched_ptr sched; |
| 431 | + llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; |
430 | 432 |
431 | 433 | // for debugging |
432 | 434 | bool debug_graph = false; |
433 | 435 | std::vector<ggml_tensor *> debug_print_tensors; |
434 | 436 |
435 | 437 | clip_ctx(clip_context_params & ctx_params) { |
| 438 | + flash_attn_type = ctx_params.flash_attn_type; |
436 | 439 | debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr; |
437 | 440 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
438 | 441 | if (!backend_cpu) { |
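Note: the new flash_attn_type member simply mirrors whatever the caller put in clip_context_params, and defaults to auto-detection. For reference, the mode enum comes from llama.h; a sketch of its shape at the time of writing (verify against your tree):

    enum llama_flash_attn_type {
        LLAMA_FLASH_ATTN_TYPE_AUTO     = -1, // probe support at load time, fall back if needed
        LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,  // always use the unfused softmax path
        LLAMA_FLASH_ATTN_TYPE_ENABLED  = 1,  // always build ggml_flash_attn_ext() into the graph
    };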
@@ -2261,16 +2264,36 @@ struct clip_graph { |
2261 | 2264 | ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); |
2262 | 2265 | //cb(k, "k", il); |
2263 | 2266 |
2264 | | - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); |
2265 | | - //cb(k, "v", il); |
| 2267 | + ggml_tensor * cur; |
2266 | 2268 |
2267 | | - k = ggml_cast(ctx0, k, GGML_TYPE_F16); |
2268 | | - v = ggml_cast(ctx0, v, GGML_TYPE_F16); |
| 2269 | + if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) { |
| 2270 | + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); |
2269 | 2271 |
2270 | | - ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); |
2271 | | - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); |
| 2272 | + k = ggml_cast(ctx0, k, GGML_TYPE_F16); |
| 2273 | + v = ggml_cast(ctx0, v, GGML_TYPE_F16); |
2272 | 2274 |
2273 | | - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); |
| 2275 | + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); |
| 2276 | + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); |
| 2277 | + |
| 2278 | + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); |
| 2279 | + |
| 2280 | + } else { |
| 2281 | + ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); |
| 2282 | + v = ggml_cont(ctx0, v); |
| 2283 | + |
| 2284 | + const auto n_tokens = q->ne[1]; |
| 2285 | + const auto n_head = q->ne[2]; |
| 2286 | + |
| 2287 | + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); |
| 2288 | + // F32 may not be needed for vision encoders?
| 2289 | + // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); |
| 2290 | + |
| 2291 | + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); |
| 2292 | + |
| 2293 | + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); |
| 2294 | + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); |
| 2295 | + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); |
| 2296 | + } |
2274 | 2297 |
2275 | 2298 | cb(cur, "kqv_out", il); |
2276 | 2299 |
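Note: the else-branch reintroduces the classic unfused attention that ggml_flash_attn_ext() computes in a single fused op. A sketch of the tensor shapes flowing through it, assuming q_cur/k_cur/v_cur arrive as [d_head, n_head, n_tokens] as elsewhere in this file (the comments are annotations, not part of the patch):

    // q, k: [d_head, n_tokens, n_head]   (after ggml_permute(.., 0, 2, 1, 3))
    // v:    [n_tokens, d_head, n_head]   (permute(1, 2, 0, 3) + ggml_cont, i.e. V^T per head)
    ggml_tensor * kq  = ggml_mul_mat(ctx0, k, q);               // K^T*Q         -> [n_tokens, n_tokens, n_head]
    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);  // softmax(scale*KQ + mask) in one op
    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);              // V^T * softmax -> [d_head, n_tokens, n_head]
    cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);                  //               -> [d_head, n_head, n_tokens]
    cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); // merge heads   -> [d_head*n_head, n_tokens]

The main saving of the fused path is not materializing the [n_tokens, n_tokens, n_head] KQ matrix, which is why the fallback warns that memory usage will increase.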
@@ -3181,7 +3204,30 @@ struct clip_model_loader { |
3181 | 3204 | } |
3182 | 3205 | } |
3183 | 3206 |
3184 | | - void alloc_compute_meta(clip_ctx & ctx_clip) { |
| 3207 | + void warmup(clip_ctx & ctx_clip) { |
| 3208 | + if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { |
| 3209 | + // try to enable flash attention to see if it's supported |
| 3210 | + ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; |
| 3211 | + bool supported = alloc_compute_meta(ctx_clip); |
| 3212 | + if (!supported) { |
| 3213 | + LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__); |
| 3214 | + // TODO: maybe log more details about why flash attention is not supported |
| 3215 | + ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
| 3216 | + alloc_compute_meta(ctx_clip); |
| 3217 | + } |
| 3218 | + } else { |
| 3219 | + bool supported = alloc_compute_meta(ctx_clip); |
| 3220 | + if (!supported) { |
| 3221 | + LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); |
| 3222 | + } |
| 3223 | + } |
| 3224 | + |
| 3225 | + LOG_INF("%s: flash attention is %s\n", __func__, |
| 3226 | + (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled"); |
| 3227 | + } |
| 3228 | + |
| 3229 | + // return false if flash attention is not supported |
| 3230 | + bool alloc_compute_meta(clip_ctx & ctx_clip) { |
3185 | 3231 | const auto & hparams = ctx_clip.model.hparams; |
3186 | 3232 | ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); |
3187 | 3233 |
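Note: warmup() makes the decision once at load time and logs the outcome, so callers only have to fill one extra field. A minimal usage sketch (field order per this patch's clip.h is assumed; the model path is a placeholder):

    clip_context_params params {
        /*use_gpu         =*/ true,
        /*verbosity       =*/ GGML_LOG_LEVEL_INFO,
        /*flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, // let warmup() probe and fall back
    };
    clip_init_result res = clip_init("/path/to/mmproj.gguf", params);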
@@ -3217,6 +3263,17 @@ struct clip_model_loader { |
3217 | 3263 | const int n_nodes = ggml_graph_n_nodes(gf); |
3218 | 3264 |
3219 | 3265 | LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes); |
| 3266 | + |
| 3267 | + // check flash attention support |
| 3268 | + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
| 3269 | + ggml_tensor * node = ggml_graph_node(gf, i); |
| 3270 | + if (node->op == GGML_OP_FLASH_ATTN_EXT) { |
| 3271 | + if (!ggml_backend_supports_op(ctx_clip.backend, node)) { |
| 3272 | + return false; |
| 3273 | + } |
| 3274 | + } |
| 3275 | + } |
| 3276 | + return true; |
3220 | 3277 | } |
3221 | 3278 |
3222 | 3279 | void get_bool(const std::string & key, bool & output, bool required = true) { |
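Note: the probe is the generic "build the measurement graph, then ask the backend about each node" pattern; ggml_backend_supports_op() reports whether the backend can execute a given node, and anything unsupported would otherwise be scheduled onto the CPU backend. The same idea as a standalone helper (hypothetical, not part of this patch):

    // true if `backend` can execute every node with the given op in `gf`
    static bool graph_op_supported(ggml_backend_t backend, ggml_cgraph * gf, enum ggml_op op) {
        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
            ggml_tensor * node = ggml_graph_node(gf, i);
            if (node->op == op && !ggml_backend_supports_op(backend, node)) {
                return false;
            }
        }
        return true;
    }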
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params |
3306 | 3363 | ctx_vision = new clip_ctx(ctx_params); |
3307 | 3364 | loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); |
3308 | 3365 | loader.load_tensors(*ctx_vision); |
3309 | | - loader.alloc_compute_meta(*ctx_vision); |
| 3366 | + loader.warmup(*ctx_vision); |
3310 | 3367 | } |
3311 | 3368 |
3312 | 3369 | if (loader.has_audio) { |
3313 | 3370 | ctx_audio = new clip_ctx(ctx_params); |
3314 | 3371 | loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); |
3315 | 3372 | loader.load_tensors(*ctx_audio); |
3316 | | - loader.alloc_compute_meta(*ctx_audio); |
| 3373 | + loader.warmup(*ctx_audio); |
3317 | 3374 | } |
3318 | 3375 |
3319 | 3376 | } catch (const std::exception & e) { |