diff --git a/common/common.cpp b/common/common.cpp index 0427f21ca..d4237e331 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1181,6 +1181,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mmproj_use_gpu = false; return true; } + if (arg == "--mtmd-kq-type") { + CHECK_ARG + params.mtmd_kq_type = argv[i]; + return true; + } if (arg == "--image" || arg == "--audio") { CHECK_ARG params.image.emplace_back(argv[i]); @@ -2489,9 +2494,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "multi-modality" }); options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); - options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"}); - options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" }); - options.push_back({ "*", " --no-context-shift", "disable context-shift." }); + options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"}); + options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" }); + options.push_back({ "*", " --mtmd-kq-type TYPE", "data type for multimodality K*Q (default: %s)", params.mtmd_kq_type.c_str() }); + options.push_back({ "*", " --no-context-shift", "disable context-shift." 
}); options.push_back({ "*", "--context-shift (auto|on|off|0|1)", "set context-shift (default: %s)", params.ctx_shift ? "on" : "off" }); options.push_back({ "backend" }); options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); diff --git a/common/common.h b/common/common.h index 1a5ab59a2..faaa36122 100644 --- a/common/common.h +++ b/common/common.h @@ -378,6 +378,7 @@ struct gpt_params { std::vector image; // path to image file(s) int image_min_tokens = -1; int image_max_tokens = -1; + std::string mtmd_kq_type = "f32"; // embedding bool embedding = false; // get only sentence embedding diff --git a/examples/mtmd/clip.cpp b/examples/mtmd/clip.cpp index 90d7f91c1..946c93fc9 100644 --- a/examples/mtmd/clip.cpp +++ b/examples/mtmd/clip.cpp @@ -443,6 +443,7 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + ggml_type kq_type = GGML_TYPE_F32; // for debugging bool debug_graph = false; @@ -450,6 +451,7 @@ struct clip_ctx { clip_ctx(clip_context_params & ctx_params) { flash_attn_type = ctx_params.flash_attn_type; + kq_type = ctx_params.kq_type; debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr; //backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); backend_cpu = ggml_backend_cpu_init(); @@ -1011,7 +1013,9 @@ struct clip_graph { // self-attention { cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + cb(cur, "qkv_w", il); cur = ggml_add(ctx0, cur, layer.qkv_b); + cb(cur, "qkv_b", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), cur->nb[1], 0); @@ -1072,12 +1076,14 @@ struct clip_graph { nullptr, nullptr, layer.deepstack_fc2_w, layer.deepstack_fc2_b, ffn_op_type::FFN_GELU, il); + cb(feat, "ffn_feat", il); if(!deepstack_features) { deepstack_features = feat; } else { // concat along the feature dimension deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0); + 
cb(deepstack_features, "feat_concat", il); } } @@ -1098,9 +1104,11 @@ struct clip_graph { nullptr, nullptr, model.mm_1_w, model.mm_1_b, ffn_op_type::FFN_GELU, -1); + cb(embeddings, "ffn_postl", -1); if (deepstack_features) { embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension + cb(embeddings, "ffn_postl_concat", -1); } // build the graph @@ -2425,6 +2433,22 @@ struct clip_graph { ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); v = ggml_cont(ctx0, v); + if (ctx->kq_type != k->type) { + auto bs = ggml_blck_size(ctx->kq_type); + if (k->ne[0] % bs != 0) { + int nbs = bs*((k->ne[0] + bs - 1)/bs); + k = ggml_pad(ctx0, k, nbs - k->ne[0], 0, 0, 0); + } + if (q->ne[0] % bs != 0) { + int nbs = bs*((q->ne[0] + bs - 1)/bs); + q = ggml_pad(ctx0, q, nbs - q->ne[0], 0, 0, 0); + } + k = ggml_cast(ctx0, k, ctx->kq_type); + if (!ggml_is_quantized(ctx->kq_type)) { + q = ggml_cast(ctx0, q, ctx->kq_type); + } + } + if (q->ne[3] == 1 && q->ne[2] > 1 && q->ne[2] == k->ne[2] && q->ne[2] == v->ne[2] && q->ne[1]*k->ne[1]*q->ne[2]/1024./1024. >= 256.) 
{ cur = nullptr; for (int64_t i2 = 0; i2 < q->ne[2]; ++i2) { @@ -2432,10 +2456,14 @@ struct clip_graph { auto ki = ggml_view_2d(ctx0, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i2); auto vi = ggml_view_2d(ctx0, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i2); auto kq = ggml_mul_mat(ctx0, ki, qi); + cb(kq, "kq_i", il); kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + cb(kq, "softmax(kq_i)", il); auto kqv = ggml_mul_mat(ctx0, vi, kq); + cb(kqv, "kqv_i", il); if (cur) { cur = ggml_concat(ctx0, cur, kqv, 0); + cb(cur, "kqv_i_concat", il); } else { cur = kqv; } diff --git a/examples/mtmd/clip.h b/examples/mtmd/clip.h index 1099370bb..54559dea8 100644 --- a/examples/mtmd/clip.h +++ b/examples/mtmd/clip.h @@ -35,6 +35,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + ggml_type kq_type; }; struct clip_init_result { diff --git a/examples/mtmd/mtmd-cli.cpp b/examples/mtmd/mtmd-cli.cpp index cfd09fb1f..55a9a14c2 100644 --- a/examples/mtmd/mtmd-cli.cpp +++ b/examples/mtmd/mtmd-cli.cpp @@ -99,6 +99,22 @@ void common_init() { #endif // ======================= end compat ================================ +static ggml_type ggml_type_from_str(const std::string & s) { + if (s == "f32") { + return GGML_TYPE_F32; + } + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "bf16") { + return GGML_TYPE_BF16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + throw std::runtime_error("Invalid mtmd K*Q type: " + s); +} + struct mtmd_cli_context { mtmd::context_ptr ctx_vision; common_init_result llama_init; @@ -171,6 +187,7 @@ struct mtmd_cli_context { mparams.flash_attn_type = params.flash_attn ?
LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; mparams.image_min_tokens = params.image_min_tokens; mparams.image_max_tokens = params.image_max_tokens; + mparams.kq_type = ggml_type_from_str(params.mtmd_kq_type); ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); diff --git a/examples/mtmd/mtmd.cpp b/examples/mtmd/mtmd.cpp index 657bed58c..77a0bbab8 100644 --- a/examples/mtmd/mtmd.cpp +++ b/examples/mtmd/mtmd.cpp @@ -102,6 +102,7 @@ mtmd_context_params mtmd_context_params_default() { /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ -1, /* image_max_tokens */ -1, + /* kq_type */ GGML_TYPE_F32, }; return params; } @@ -170,6 +171,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_DISABLED, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* kq_type */ ctx_params.kq_type, }; auto res = clip_init(mmproj_fname, ctx_clip_params); diff --git a/examples/mtmd/mtmd.h b/examples/mtmd/mtmd.h index a76f29b84..85084abc2 100644 --- a/examples/mtmd/mtmd.h +++ b/examples/mtmd/mtmd.h @@ -87,6 +87,7 @@ struct mtmd_context_params { // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) int image_max_tokens; // maximum number of tokens for image input (default: read from metadata) + ggml_type kq_type; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a08579076..b7a4a14e4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3971,12 +3971,15 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, int i = 0; ggml_float sum = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 vsum = _mm512_setzero_ps(); for (; i + 15 < n; i += 16) { __m512 val = 
ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), _mm512_set1_ps(max))); _mm512_storeu_ps(y + i, val); - sum += (ggml_float)_mm512_reduce_add_ps(val); + vsum = _mm512_add_ps(vsum, val); + //sum += (ggml_float)_mm512_reduce_add_ps(val); } + sum = (ggml_float)_mm512_reduce_add_ps(vsum); #elif defined(__AVX2__) && defined(__FMA__) for (; i + 7 < n; i += 8) { __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), @@ -21667,16 +21670,20 @@ static void ggml_compute_forward_flash_attn_ext_f16( #if GGML_USE_IQK_MULMAT // For now we do not implement sinks in the iqk FA implementation - if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, + if (iqk_flash_attn_noalibi(q->type, mask ? mask->type : GGML_TYPE_F16, max_bias, q->ne[3], q->ne[2], q->nb[3], q->nb[2], k->ne[3], k->ne[2], k->nb[3], k->nb[2], v->ne[3], v->ne[2], v->nb[3], v->nb[2], dst->ne[2], dst->ne[1], dst->nb[1], k->type, v->type, - Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], - q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL, + Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask ? mask->nb[1] : 0, + q->data, k->data, v->data, mask ? mask->data : NULL, sinks ? 
sinks->data : NULL, scale, softcap, (float *)dst->data, params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return; + printf("iqk_flash_attn_noalibi returned false for Dk = %ld, Dv = %ld, mask = %p:\n", Dk, Dv, (const void *)mask); + printf(" q(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(q->type), q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + printf(" k(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(k->type), k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + printf(" v(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(v->type), v->ne[0], v->ne[1], v->ne[2], v->ne[3]); // if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { // //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index aa01237d9..f30c2b80f 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -184,7 +184,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float GGML_ABORT("Fatal error"); } - if (n_swa > 0) { + if (n_swa > 0 && mask) { constexpr int kMinBatch = 256; int ntokens = std::max(kMinBatch, neq1); int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch; @@ -203,7 +203,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int rv3 = neq3/nev3; int first_k = 0, last_k = nek1; - if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) { + if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256 && mask) { // This is a quick hack for SWA models. // Given that the mask is the same for all layers, ideally we should determine the // cache bounds once, and reuse for the whole graph. 
But even with this simple hack @@ -271,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto kth = (const char *)k + kv_offset*stride_k; auto vth = (const char *)v + kv_offset*stride_v; auto qth = (const char *)q; - auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here + auto mth = mask ? (const char *)mask + kv_offset*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here auto work = (char *)work_buffer; auto size_thread = (Dv + 16)*rk2*sizeof(float); @@ -322,7 +322,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2; auto qth = (const char *)q + q_offset; - auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here + auto mth = mask ? (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here // Each thread will produce a result of size Dv*nq_this_thread*sizeof(float) // In addition, we need M, S for the nq_this_thread rows the thread is processing @@ -403,7 +403,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; - auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + auto this_m = mask ? 
(const char *)mask + ik01*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0, @@ -473,7 +473,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), - (const void *)((const char *)mask + iq1*stride_m), sinksf, 1, + mask ? (const void *)((const char *)mask + iq1*stride_m) : nullptr, sinksf, 1, scale, softcap, (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; } diff --git a/ggml/src/iqk/iqk_gemm_floats.cpp b/ggml/src/iqk/iqk_gemm_floats.cpp index 4d0eba8d0..a2a3c3c67 100644 --- a/ggml/src/iqk/iqk_gemm_floats.cpp +++ b/ggml/src/iqk/iqk_gemm_floats.cpp @@ -420,6 +420,32 @@ template struct QFTBF16 final : public QFBaseBF16 { IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } const ggml_bf16_t * y[nrc]; }; +struct QFBaseBF16x8 { + constexpr static int k_step = 16; + using Data = __m256bh; + using Acc = __m256; + static inline Data load(const ggml_bf16_t * x) { return __m256bh(_mm256_loadu_si256((const __m256i *)x)); } + static inline Acc acc(Acc prev, Data y, Data x) { + return _mm256_dpbf16_ps(prev, y, x); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return _mm256_dpbf16_ps(_mm256_setzero_ps(), y, x); + } + static inline float hsum(Acc acc) { + return hsum_float_8(acc); + } +}; +template struct QFTBF16x8 final : public QFBaseBF16x8 { + constexpr static int nrc = nrc_in; + QFTBF16x8(const DataInfo& info) { + for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + } + QFTBF16x8(const char * cx, size_t bx) { + for (int iy 
= 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx); + } + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + const ggml_bf16_t * y[nrc]; +}; template IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { @@ -476,6 +502,61 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in case 4: mul_mat_Qx_Qy_MxN(n, cx, bx, last_x, info); break; } } +template +IQK_NOINLINE void mul_mat_Qx_Qy_MxNx8(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + int nb = n/QFBaseBF16x8::k_step; + QFTBF16x8 y(info); + QFTBF16x8 x(cx + ix0*bx, bx); + QFBaseBF16x8::Data xv[nrc_x]; + QFBaseBF16x8::Acc acc[nrc_x*nrc_y]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QFBaseBF16x8::acc_first(yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc_first(yv, xv[ix]); + } + for (int i = 1; i < nb; ++i) { + yv = y.load1(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QFBaseBF16x8::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16x8::hsum(acc[nrc_x*iy+ix])); +} + +template +void mul_mat_fX_fY_Tx8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + constexpr int k_nx = nrc_y <= 2 ? 
8 : 5; + const char * cx = (const char *)vx; + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_Qx_Qy_MxNx8(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + if constexpr (nrc_y <= 2) { + if (nx >= 4) { + mul_mat_Qx_Qy_MxNx8(n, cx, bx, last_x, info); + last_x += 4; + if (last_x == nrc_x) return; + nx = nrc_x - last_x; + } + } + switch (nx) { + case 1: mul_mat_Qx_Qy_MxNx8(n, cx, bx, last_x, info); break; + case 2: mul_mat_Qx_Qy_MxNx8(n, cx, bx, last_x, info); break; + case 3: mul_mat_Qx_Qy_MxNx8(n, cx, bx, last_x, info); break; + case 4: mul_mat_Qx_Qy_MxNx8(n, cx, bx, last_x, info); break; + } +} #endif @@ -501,6 +582,14 @@ void set_mul_mat_bf16(std::array& funcs) { funcs[3] = mul_mat_fX_fY_T<4>; funcs[4] = mul_mat_fX_fY_T<5>; } +void set_mul_mat_bf16x8(std::array& funcs) { + for (auto& f : funcs) f = nullptr; + funcs[0] = mul_mat_fX_fY_Tx8<1>; + funcs[1] = mul_mat_fX_fY_Tx8<2>; + funcs[2] = mul_mat_fX_fY_Tx8<3>; + funcs[3] = mul_mat_fX_fY_Tx8<4>; + funcs[4] = mul_mat_fX_fY_Tx8<5>; +} void set_mul_mat_bf16_r16(std::array& funcs) { for (auto& f : funcs) f = nullptr; funcs[0] = mul_mat_bf16_r16_bf16<1>; @@ -519,10 +608,16 @@ void set_mul_mat_bf16_r16(std::array& funcs) { bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array& kernels) { if (typeA == GGML_TYPE_BF16) { - if (ne00 % 16) return false; + if (ne00 % 8) return false; switch (typeB) { #ifdef __AVX512BF16__ - case GGML_TYPE_BF16: set_mul_mat_bf16(kernels); break; + case GGML_TYPE_BF16: { + if (ne00 % 16 == 0) { + set_mul_mat_bf16(kernels); + } else { + set_mul_mat_bf16x8(kernels); + } + } break; #else case GGML_TYPE_BF16: set_mul_mat_f(kernels); break; case GGML_TYPE_F32: set_mul_mat_f(kernels); break;