
Commit bdb43f6

clip : print more detailed op support info during warmup
1 parent b4955f0 commit bdb43f6

File tree: 4 files changed, +108 -54 lines


tools/mtmd/clip.cpp

Lines changed: 88 additions & 46 deletions
@@ -4,10 +4,8 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
-#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
-#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"
@@ -18,15 +16,12 @@
 #include <cstring>
 #include <fstream>
 #include <map>
-#include <regex>
 #include <stdexcept>
 #include <unordered_set>
 #include <vector>
-#include <sstream>
 #include <cinttypes>
 #include <limits>
 #include <array>
-#include <numeric>
 #include <functional>
 
 // TODO: allow to pass callback from user code
@@ -428,7 +423,7 @@ struct clip_ctx {
 
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
-    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+    clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
 
     // for debugging
     bool debug_graph = false;
@@ -2266,7 +2261,7 @@ struct clip_graph {
 
         ggml_tensor * cur;
 
-        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
             ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
 
             k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -3204,30 +3199,58 @@ struct clip_model_loader {
         }
     }
 
-    void warmup(clip_ctx & ctx_clip) {
-        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+    struct support_info_op {
+        ggml_tensor * op;
+
+        // true if the op runs on the accelerated ctx_clip.backend
+        bool is_accel = true;
+    };
+
+    struct support_info_graph {
+        // whether the clip_ctx.backend supports flash attention
+        bool fattn = true;
+
+        std::vector<support_info_op> ops;
+    };
+
+    static void warmup(clip_ctx & ctx_clip) {
+        support_info_graph info;
+
+        if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
-            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn) {
                 LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
                 // TODO: maybe log more details about why flash attention is not supported
-                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
                 alloc_compute_meta(ctx_clip);
             }
         } else {
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }
 
         LOG_INF("%s: flash attention is %s\n", __func__,
-                (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+                (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+        // print ops that are not supported by the GPU backend (if there is one)
+        if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            for (const auto & op : info.ops) {
+                if (!op.is_accel) {
+                    LOG_WRN("%s: op %16s is not supported by the CLIP backend: type = %s, ne = [%d %d %d %d]\n", __func__,
+                        ggml_op_name(op.op->op),
+                        ggml_type_name(op.op->type),
+                        op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
+                }
+            }
+        }
     }
 
-    // return false if flash attention is not supported
-    bool alloc_compute_meta(clip_ctx & ctx_clip) {
+    // return false if at least one op is not supported by the backend
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
@@ -3264,67 +3287,87 @@ struct clip_model_loader {
 
         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
 
-        // check flash attention support
+        support_info_graph res {
+            /*.fattn = */ true,
+            /*.ops = */ {},
+        };
+
+        // check op support
         for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
             ggml_tensor * node = ggml_graph_node(gf, i);
-            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
-                    return false;
+            res.ops.push_back({node, true});
+            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                res.ops.back().is_accel = false;
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    res.fattn = false;
                 }
             }
         }
-        return true;
+
+        return res;
     }
 
-    void get_bool(const std::string & key, bool & output, bool required = true) {
+    void get_bool(const std::string & key, bool & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_bool(ctx_gguf.get(), i);
     }
 
-    void get_i32(const std::string & key, int & output, bool required = true) {
+    void get_i32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
        }
        output = gguf_get_val_i32(ctx_gguf.get(), i);
     }
 
-    void get_u32(const std::string & key, int & output, bool required = true) {
+    void get_u32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
        }
        output = gguf_get_val_u32(ctx_gguf.get(), i);
     }
 
-    void get_f32(const std::string & key, float & output, bool required = true) {
+    void get_f32(const std::string & key, float & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
        }
        output = gguf_get_val_f32(ctx_gguf.get(), i);
     }
 
-    void get_string(const std::string & key, std::string & output, bool required = true) {
+    void get_string(const std::string & key, std::string & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
        }
        output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
     }
 
-    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
        }
        int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3335,7 +3378,7 @@ struct clip_model_loader {
         }
     }
 
-    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+    static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
             for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3375,12 +3418,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
 
     } catch (const std::exception & e) {
         LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-        if (ctx_vision) {
-            delete ctx_vision;
-        }
-        if (ctx_audio) {
-            delete ctx_audio;
-        }
+
+        delete ctx_vision;
+        delete ctx_audio;
+
         return {nullptr, nullptr};
     }
 
@@ -3418,10 +3459,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
 
 size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
     return batch->entries.size();
@@ -4539,6 +4580,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
         LOG_WRN("%s: *****************************************************************\n", __func__);
         LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+        LOG_WRN("%s: use GGML_SCHED_DEBUG=2 to determine which ops \n", __func__);
         LOG_WRN("%s: the performance will be suboptimal \n", __func__);
         LOG_WRN("%s: \n", __func__);
         LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);

tools/mtmd/clip.h

Lines changed: 8 additions & 2 deletions
@@ -1,7 +1,7 @@
 #pragma once
 
 #include "ggml.h"
-#include "mtmd.h"
+
 #include <stddef.h>
 #include <stdint.h>
 
@@ -23,10 +23,16 @@ enum clip_modality {
     CLIP_MODALITY_AUDIO,
 };
 
+enum clip_flash_attn_type {
+    CLIP_FLASH_ATTN_TYPE_AUTO = -1,
+    CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+    CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
-    llama_flash_attn_type flash_attn_type;
+    enum clip_flash_attn_type flash_attn_type;
 };
 
 struct clip_init_result {
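
A hypothetical caller-side sketch of the reworked clip.h surface; the mmproj file name is a placeholder, only the fields visible in this diff are set, and the ctx_v / ctx_a members of clip_init_result are taken from how mtmd.cpp uses them below.

// Sketch under the above assumptions; not an official example from the commit.
#include "clip.h" // pulls in ggml.h, which provides GGML_LOG_LEVEL_INFO

int main() {
    struct clip_context_params params;
    params.use_gpu         = true;
    params.verbosity       = GGML_LOG_LEVEL_INFO;
    params.flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; // let warmup() probe backend support

    struct clip_init_result res = clip_init("mmproj-model.gguf", params);
    return (res.ctx_v || res.ctx_a) ? 0 : 1; // either context may be null on failure
}

With CLIP_FLASH_ATTN_TYPE_AUTO, warmup() first tries flash attention and falls back to CLIP_FLASH_ATTN_TYPE_DISABLED when the backend rejects GGML_OP_FLASH_ATTN_EXT.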

tools/mtmd/mtmd.cpp

Lines changed: 11 additions & 5 deletions
@@ -19,7 +19,6 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
-#include <limits>
 #include <vector>
 
 // represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
     return "<__media__>";
 }
 
+static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
+    }
+    return CLIP_FLASH_ATTN_TYPE_AUTO;
+}
+
 mtmd_context_params mtmd_context_params_default() {
     mtmd_context_params params;
     params.use_gpu = true;
@@ -165,7 +173,7 @@ struct mtmd_context {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
-        ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
+        ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;
@@ -380,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
 }
 
 void mtmd_free(mtmd_context * ctx) {
-    if (ctx) {
-        delete ctx;
-    }
+    delete ctx;
 }
 
 struct mtmd_tokenizer {
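
On the mtmd side, callers keep passing the llama-level enum and the conversion to the clip enum happens internally. A hedged usage sketch follows; the parameter list of mtmd_init_from_file() beyond mmproj_fname is assumed from its declaration above, and the path plus text_model are placeholders.

// Assumption: mtmd_init_from_file(mmproj_fname, text_model, params), as declared in mtmd.h.
#include "mtmd.h"

static mtmd_context * load_mmproj(const char * mmproj_path, const llama_model * text_model) {
    mtmd_context_params params = mtmd_context_params_default();
    params.use_gpu         = true;
    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // converted via mtmd_get_clip_flash_attn_type()

    return mtmd_init_from_file(mmproj_path, text_model, params);
}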

tools/mtmd/mtmd.h

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
-    llama_flash_attn_type flash_attn_type;
+    enum llama_flash_attn_type flash_attn_type;
 };
 
 MTMD_API const char * mtmd_default_marker(void);
