1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 72 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
op->src[0]->ne[0] != 112 &&
8 changes: 8 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
4 changes: 2 additions & 2 deletions tests/test-backend-ops.cpp
@@ -7225,8 +7225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
}

for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
80 changes: 70 additions & 10 deletions tools/mtmd/clip.cpp
@@ -4,6 +4,7 @@
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "ggml.h"
#include "ggml-cpp.h"
#include "ggml-cpu.h"
@@ -28,6 +29,7 @@
#include <numeric>
#include <functional>

// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

enum ffn_op_type {
@@ -426,12 +428,14 @@ struct clip_ctx {

int max_nodes = 8192;
ggml_backend_sched_ptr sched;
llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;

// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;

clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
@@ -2260,17 +2264,25 @@ struct clip_graph {
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);

ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
//cb(k, "v", il);

ggml_tensor * cur;

// TODO @ngxson : support flash attention
{
if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

k = ggml_cast(ctx0, k, GGML_TYPE_F16);
v = ggml_cast(ctx0, v, GGML_TYPE_F16);

cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

} else {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
// const auto n_kv = k->ne[1]; // for flash attention

ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
@@ -3192,7 +3204,30 @@ struct clip_model_loader {
}
}

void alloc_compute_meta(clip_ctx & ctx_clip) {
void warmup(clip_ctx & ctx_clip) {
if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
bool supported = alloc_compute_meta(ctx_clip);
if (!supported) {
LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
// TODO: maybe log more details about why flash attention is not supported
Collaborator:

@ggerganov I implemented a simple solution to auto-enable flash attn only when the backend supports it. We should probably make this LOG_WRN more prominent. Also, which kind of info do you think should be displayed here?

Some users are potentially already using models with shapes not supported by GPU flash attn. Falling back to CPU would suddenly make them very slow and thus not a good UX overall. The auto mode + a prominent warning is a better solution, as it also encourages users to "voluntarily" report certain info back to us - less forcefully for them.

Member Author:

Also, which kind of info do you think should be displayed here?

We can print the actual tensor (shape, strides, types) for which FA is not supported.
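A minimal sketch of the kind of logging suggested here (illustrative only, not part of this diff): a helper that warmup() could call for the unsupported GGML_OP_FLASH_ATTN_EXT node. The helper name and call site are assumptions; LOG_WRN, ggml_op_name and ggml_type_name are the existing clip/ggml APIs.

// Illustrative sketch only - not part of this diff.
// Prints the shape, strides and types of the node (and its sources) for which
// flash attention was rejected by the backend.
static void clip_log_unsupported_fa(const ggml_tensor * node) {
    LOG_WRN("%s: unsupported %s node '%s'\n", __func__, ggml_op_name(node->op), node->name);
    for (int i = 0; i < GGML_MAX_SRC && node->src[i] != nullptr; i++) {
        const ggml_tensor * s = node->src[i];
        LOG_WRN("%s:   src[%d] '%s': type=%s, ne=[%lld, %lld, %lld, %lld], nb=[%zu, %zu, %zu, %zu]\n",
                __func__, i, s->name, ggml_type_name(s->type),
                (long long) s->ne[0], (long long) s->ne[1], (long long) s->ne[2], (long long) s->ne[3],
                s->nb[0], s->nb[1], s->nb[2], s->nb[3]);
    }
}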

ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip);
}
} else {
bool supported = alloc_compute_meta(ctx_clip);
if (!supported) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
}

LOG_INF("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
}

// return false if flash attention is not supported
bool alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());

@@ -3223,6 +3258,22 @@ struct clip_model_loader {
size / 1024.0 / 1024.0);
}
}

const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
const int n_nodes = ggml_graph_n_nodes(gf);

LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);

// check flash attention support
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * node = ggml_graph_node(gf, i);
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
return false;
}
}
}
Comment on lines 3326 to 3336

Collaborator:

I tested this by temporarily modifying the code to always return false in ggml_metal_device_supports_op; it works well for now, but I'm not sure if there are any edge cases.

Currently, mtmd only supports 2 backends at the same time: CPU and one GPU backend.
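The temporary change described above would look roughly like the sketch below (testing hack only, not part of this diff; it assumes the flash-attention branch of ggml_metal_device_supports_op is a case in a switch on op->op, as suggested by the hunk at the top of this PR).

// Testing hack (do not commit): make the Metal backend report flash attention as
// unsupported, so clip_model_loader::warmup() is forced onto the fallback path.
case GGML_OP_FLASH_ATTN_EXT:
    return false; // instead of the per-head-size checks on op->src[0]->ne[0]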

return true;
}

void get_bool(const std::string & key, bool & output, bool required = true) {
@@ -3312,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
ctx_vision = new clip_ctx(ctx_params);
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
loader.load_tensors(*ctx_vision);
loader.alloc_compute_meta(*ctx_vision);
loader.warmup(*ctx_vision);
}

if (loader.has_audio) {
ctx_audio = new clip_ctx(ctx_params);
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
loader.load_tensors(*ctx_audio);
loader.alloc_compute_meta(*ctx_audio);
loader.warmup(*ctx_audio);
}

} catch (const std::exception & e) {
@@ -4485,6 +4536,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
return false; // only support batch size of 1
}

if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
LOG_WRN("%s: the performance will be suboptimal \n", __func__);
LOG_WRN("%s: \n", __func__);
LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
}

// build the inference graph
ctx->debug_print_tensors.clear();
ggml_backend_sched_reset(ctx->sched.get());
2 changes: 2 additions & 0 deletions tools/mtmd/clip.h
@@ -1,6 +1,7 @@
#pragma once

#include "ggml.h"
#include "mtmd.h"
#include <stddef.h>
#include <stdint.h>

@@ -25,6 +26,7 @@ enum clip_modality {
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
llama_flash_attn_type flash_attn_type;
};

struct clip_init_result {
1 change: 1 addition & 0 deletions tools/mtmd/mtmd-cli.cpp
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params.flash_attn_type;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);