@@ -4,10 +4,8 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
-#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
-#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"
@@ -18,15 +16,12 @@
 #include <cstring>
 #include <fstream>
 #include <map>
-#include <regex>
 #include <stdexcept>
 #include <unordered_set>
 #include <vector>
-#include <sstream>
 #include <cinttypes>
 #include <limits>
 #include <array>
-#include <numeric>
 #include <functional>

 // TODO: allow to pass callback from user code
@@ -428,7 +423,7 @@ struct clip_ctx {

     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
-    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+    clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
@@ -2266,7 +2261,7 @@ struct clip_graph {

         ggml_tensor * cur;

-        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
             ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

             k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -3204,30 +3199,58 @@ struct clip_model_loader {
         }
     }

-    void warmup(clip_ctx & ctx_clip) {
-        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+    struct support_info_op {
+        ggml_tensor * op;
+
+        // true if the op runs on the accelerated ctx_clip.backend
+        bool is_accel = true;
+    };
+
+    struct support_info_graph {
+        // whether the clip_ctx.backend supports flash attention
+        bool fattn = true;
+
+        std::vector<support_info_op> ops;
+    };
+
+    static void warmup(clip_ctx & ctx_clip) {
+        support_info_graph info;
+
+        if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
-            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn) {
                 LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
                 // TODO: maybe log more details about why flash attention is not supported
-                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
                 alloc_compute_meta(ctx_clip);
             }
         } else {
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }

         LOG_INF("%s: flash attention is %s\n", __func__,
-                (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+                (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+        // print ops that are not supported by the GPU backend (if there is one)
+        if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            for (const auto & op : info.ops) {
+                if (!op.is_accel) {
+                    LOG_WRN("%s: op %16s is not supported by the CLIP backend: type = %s, ne = [%d %d %d %d]\n", __func__,
+                        ggml_op_name(op.op->op),
+                        ggml_type_name(op.op->type),
+                        op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
+                }
+            }
+        }
     }

-    // return false if flash attention is not supported
-    bool alloc_compute_meta(clip_ctx & ctx_clip) {
+    // returns per-op support info; fattn is false if flash attention is not supported by the backend
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());

@@ -3264,67 +3287,87 @@ struct clip_model_loader {

         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);

-        // check flash attention support
+        support_info_graph res {
+            /*.fattn =*/ true,
+            /*.ops   =*/ {},
+        };
+
+        // check op support
         for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
             ggml_tensor * node = ggml_graph_node(gf, i);
-            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
-                    return false;
+            res.ops.push_back({node, true});
+            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                res.ops.back().is_accel = false;
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    res.fattn = false;
                 }
             }
         }
-        return true;
+
+        return res;
     }

-    void get_bool(const std::string & key, bool & output, bool required = true) {
+    void get_bool(const std::string & key, bool & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_bool(ctx_gguf.get(), i);
     }

-    void get_i32(const std::string & key, int & output, bool required = true) {
+    void get_i32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_i32(ctx_gguf.get(), i);
     }

-    void get_u32(const std::string & key, int & output, bool required = true) {
+    void get_u32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_u32(ctx_gguf.get(), i);
     }

-    void get_f32(const std::string & key, float & output, bool required = true) {
+    void get_f32(const std::string & key, float & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_f32(ctx_gguf.get(), i);
     }

-    void get_string(const std::string & key, std::string & output, bool required = true) {
+    void get_string(const std::string & key, std::string & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
     }

-    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3335,7 +3378,7 @@ struct clip_model_loader {
         }
     }

-    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+    static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
             for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3375,12 +3418,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params

     } catch (const std::exception & e) {
         LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-        if (ctx_vision) {
-            delete ctx_vision;
-        }
-        if (ctx_audio) {
-            delete ctx_audio;
-        }
+
+        delete ctx_vision;
+        delete ctx_audio;
+
         return {nullptr, nullptr};
     }

@@ -3418,10 +3459,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }

 size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
     return batch->entries.size();
@@ -4539,6 +4580,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
         LOG_WRN("%s: *****************************************************************\n", __func__);
         LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+        LOG_WRN("%s: use GGML_SCHED_DEBUG=2 to determine which ops\n", __func__);
         LOG_WRN("%s: the performance will be suboptimal\n", __func__);
         LOG_WRN("%s: \n", __func__);
         LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);