Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.mmproj_use_gpu = false;
return true;
}
if (arg == "--mtmd-kq-type") {
CHECK_ARG
params.mtmd_kq_type = argv[i];
return true;
}
if (arg == "--image" || arg == "--audio") {
CHECK_ARG
params.image.emplace_back(argv[i]);
Expand Down Expand Up @@ -2489,9 +2494,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "multi-modality" });
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" });
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
options.push_back({ "*", " --mtmd-kq-type TYPE", "data type for multimodality K*Q (default: %s)", params.mtmd_kq_type.c_str() });
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
options.push_back({ "*", "--context-shift (auto|on|off|0|1)", "set context-shift (default: %s)", params.ctx_shift ? "on" : "off" });
options.push_back({ "backend" });
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ struct gpt_params {
std::vector<std::string> image; // path to image file(s)
int image_min_tokens = -1;
int image_max_tokens = -1;
std::string mtmd_kq_type = "f32";

// embedding
bool embedding = false; // get only sentence embedding
Expand Down
28 changes: 28 additions & 0 deletions examples/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,15 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
ggml_type kq_type = GGML_TYPE_F32;

// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;

clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
kq_type = ctx_params.kq_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
//backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
backend_cpu = ggml_backend_cpu_init();
Expand Down Expand Up @@ -1011,7 +1013,9 @@ struct clip_graph {
// self-attention
{
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cb(cur, "qkv_w", il);
cur = ggml_add(ctx0, cur, layer.qkv_b);
cb(cur, "qkv_b", il);

ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 0);
Expand Down Expand Up @@ -1072,12 +1076,14 @@ struct clip_graph {
nullptr, nullptr,
layer.deepstack_fc2_w, layer.deepstack_fc2_b,
ffn_op_type::FFN_GELU, il);
cb(feat, "ffn_feat", il);

if(!deepstack_features) {
deepstack_features = feat;
} else {
// concat along the feature dimension
deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
cb(deepstack_features, "feat_concat", il);
}
}

Expand All @@ -1098,9 +1104,11 @@ struct clip_graph {
nullptr, nullptr,
model.mm_1_w, model.mm_1_b,
ffn_op_type::FFN_GELU, -1);
cb(embeddings, "ffn_postl", -1);

if (deepstack_features) {
embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
cb(embeddings, "ffn_postl_concat", -1);
}

// build the graph
Expand Down Expand Up @@ -2425,17 +2433,37 @@ struct clip_graph {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);

if (ctx->kq_type != k->type) {
auto bs = ggml_blck_size(ctx->kq_type);
if (k->ne[0] % bs != 0) {
int nbs = bs*((k->ne[0] + bs - 1)/bs);
k = ggml_pad(ctx0, k, nbs - k->ne[0], 0, 0, 0);
}
if (q->ne[0] % bs != 0) {
int nbs = bs*((q->ne[0] + bs - 1)/bs);
q = ggml_pad(ctx0, q, nbs - q->ne[0], 0, 0, 0);
}
k = ggml_cast(ctx0, k, ctx->kq_type);
if (!ggml_is_quantized(ctx->kq_type)) {
q = ggml_cast(ctx0, q, ctx->kq_type);
}
}

if (q->ne[3] == 1 && q->ne[2] > 1 && q->ne[2] == k->ne[2] && q->ne[2] == v->ne[2] && q->ne[1]*k->ne[1]*q->ne[2]/1024./1024. >= 256.) {
cur = nullptr;
for (int64_t i2 = 0; i2 < q->ne[2]; ++i2) {
auto qi = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i2);
auto ki = ggml_view_2d(ctx0, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i2);
auto vi = ggml_view_2d(ctx0, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i2);
auto kq = ggml_mul_mat(ctx0, ki, qi);
cb(kq, "kq_i", il);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
cb(kq, "softmax(kq_i)", il);
auto kqv = ggml_mul_mat(ctx0, vi, kq);
cb(kqv, "kqv_i", il);
if (cur) {
cur = ggml_concat(ctx0, cur, kqv, 0);
cb(cur, "kqv_i_concat", il);
} else {
cur = kqv;
}
Expand Down
1 change: 1 addition & 0 deletions examples/mtmd/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct clip_context_params {
enum clip_flash_attn_type flash_attn_type;
int image_min_tokens;
int image_max_tokens;
ggml_type kq_type;
};

struct clip_init_result {
Expand Down
17 changes: 17 additions & 0 deletions examples/mtmd/mtmd-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ void common_init() {
#endif
// ======================= end compat ================================

// Parse a user-supplied type name (from the --mtmd-kq-type CLI argument)
// into the corresponding ggml_type.
// Accepted values: "f32", "f16", "bf16", "q8_0".
// Throws std::runtime_error for any other string.
static ggml_type ggml_type_from_str(const std::string & s) {
    if (s == "f32") {
        return GGML_TYPE_F32;
    }
    if (s == "f16") {
        return GGML_TYPE_F16;
    }
    if (s == "bf16") {
        return GGML_TYPE_BF16;
    }
    if (s == "q8_0") {
        return GGML_TYPE_Q8_0;
    }
    // NOTE: message previously said "Invalid cache type" — a copy-paste from the
    // KV-cache type parser; this function validates the mtmd K*Q type instead.
    throw std::runtime_error("Invalid mtmd K*Q type: " + s);
}

struct mtmd_cli_context {
mtmd::context_ptr ctx_vision;
common_init_result llama_init;
Expand Down Expand Up @@ -171,6 +187,7 @@ struct mtmd_cli_context {
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
mparams.kq_type = ggml_type_from_str(params.mtmd_kq_type);
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);
Expand Down
2 changes: 2 additions & 0 deletions examples/mtmd/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ mtmd_context_params mtmd_context_params_default() {
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
/* image_min_tokens */ -1,
/* image_max_tokens */ -1,
/* kq_type */ GGML_TYPE_F32,
};
return params;
}
Expand Down Expand Up @@ -170,6 +171,7 @@ struct mtmd_context {
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_DISABLED,
/* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens,
/* kq_type */ ctx_params.kq_type,
};

auto res = clip_init(mmproj_fname, ctx_clip_params);
Expand Down
1 change: 1 addition & 0 deletions examples/mtmd/mtmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ struct mtmd_context_params {
// limit number of image tokens, only for vision models with dynamic resolution
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
ggml_type kq_type;
};

MTMD_API const char * mtmd_default_marker(void);
Expand Down
15 changes: 11 additions & 4 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3971,12 +3971,15 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 vsum = _mm512_setzero_ps();
for (; i + 15 < n; i += 16) {
__m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(max)));
_mm512_storeu_ps(y + i, val);
sum += (ggml_float)_mm512_reduce_add_ps(val);
vsum = _mm512_add_ps(vsum, val);
//sum += (ggml_float)_mm512_reduce_add_ps(val);
}
sum = (ggml_float)_mm512_reduce_add_ps(vsum);
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
__m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
Expand Down Expand Up @@ -21667,16 +21670,20 @@ static void ggml_compute_forward_flash_attn_ext_f16(

#if GGML_USE_IQK_MULMAT
// For now we do not implement sinks in the iqk FA implementation
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
if (iqk_flash_attn_noalibi(q->type, mask ? mask->type : GGML_TYPE_F16, max_bias,
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
v->ne[3], v->ne[2], v->nb[3], v->nb[2],
dst->ne[2], dst->ne[1], dst->nb[1],
k->type, v->type,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask ? mask->nb[1] : 0,
q->data, k->data, v->data, mask ? mask->data : NULL, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
printf("iqk_flash_attn_noalibi returned false for Dk = %ld, Dv = %ld, mask = %p:\n", Dk, Dv, (const void *)mask);
printf(" q(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(q->type), q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
printf(" k(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(k->type), k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
printf(" v(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(v->type), v->ne[0], v->ne[1], v->ne[2], v->ne[3]);

// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
Expand Down
12 changes: 6 additions & 6 deletions ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
GGML_ABORT("Fatal error");
}

if (n_swa > 0) {
if (n_swa > 0 && mask) {
constexpr int kMinBatch = 256;
int ntokens = std::max(kMinBatch, neq1);
int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
Expand All @@ -203,7 +203,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rv3 = neq3/nev3;

int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256 && mask) {
// This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determine the
// cache bounds once, and reuse for the whole graph. But even with this simple hack
Expand Down Expand Up @@ -271,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto kth = (const char *)k + kv_offset*stride_k;
auto vth = (const char *)v + kv_offset*stride_v;
auto qth = (const char *)q;
auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here
auto mth = mask ? (const char *)mask + kv_offset*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here

auto work = (char *)work_buffer;
auto size_thread = (Dv + 16)*rk2*sizeof(float);
Expand Down Expand Up @@ -322,7 +322,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
auto qth = (const char *)q + q_offset;
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
auto mth = mask ? (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here

// Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
// In addition, we need M, S for the nq_this_thread rows the thread is processing
Expand Down Expand Up @@ -403,7 +403,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
auto this_m = mask ? (const char *)mask + ik01*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
Expand Down Expand Up @@ -473,7 +473,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
mask ? (const void *)((const char *)mask + iq1*stride_m) : nullptr, sinksf, 1,
scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
}
Expand Down
Loading